Add support for R2DBC

This commit introduces support for R2DBC ("Reactive Relational Database
Connectivity") with custom ConnectionFactory implementations, a
functional DatabaseClient for SQL execution, transaction management, a
bind marker abstraction, database initialization utilities, and
exception translation.

Closes gh-25065
This commit is contained in:
Mark Paluch 2020-05-13 15:54:25 +02:00 committed by Sam Brannen
parent 7f79a373c3
commit aff601edf1
114 changed files with 13150 additions and 1 deletions

View File

@ -26,6 +26,7 @@ configure(allprojects) { project ->
mavenBom "com.fasterxml.jackson:jackson-bom:2.11.0"
mavenBom "io.netty:netty-bom:4.1.50.Final"
mavenBom "io.projectreactor:reactor-bom:2020.0.0-SNAPSHOT"
mavenBom "io.r2dbc:r2dbc-bom:Arabba-SR5"
mavenBom "io.rsocket:rsocket-bom:1.0.1"
mavenBom "org.eclipse.jetty:jetty-bom:9.4.30.v20200611"
mavenBom "org.jetbrains.kotlin:kotlin-bom:1.4-M2"

View File

@ -24,6 +24,7 @@ include "spring-expression"
include "spring-instrument"
include "spring-jcl"
include "spring-jdbc"
include "spring-r2dbc"
include "spring-jms"
include "spring-messaging"
include "spring-orm"

View File

@ -0,0 +1,29 @@
description = "Spring R2DBC"

apply plugin: "kotlin"

// Workaround for https://youtrack.jetbrains.com/issue/KT-39610
// NOTE(review): presumably the Kotlin plugin leaves the "optional"
// configuration without a usage attribute, so it is set explicitly here —
// see the linked issue for details.
configurations["optional"].attributes.attribute(
		org.gradle.api.attributes.Usage.USAGE_ATTRIBUTE,
		objects.named(Usage.class, "java-runtime")
)

dependencies {
	// Required Spring modules and the R2DBC SPI / Reactor runtime.
	compile(project(":spring-beans"))
	compile(project(":spring-core"))
	compile(project(":spring-tx"))
	compile("io.r2dbc:r2dbc-spi")
	compile("io.projectreactor:reactor-core")
	// Kotlin/coroutines support is optional at runtime.
	compileOnly(project(":kotlin-coroutines"))
	optional("org.jetbrains.kotlin:kotlin-reflect")
	optional("org.jetbrains.kotlin:kotlin-stdlib")
	optional("org.jetbrains.kotlinx:kotlinx-coroutines-core")
	optional("org.jetbrains.kotlinx:kotlinx-coroutines-reactor")
	testCompile(project(":kotlin-coroutines"))
	testCompile(testFixtures(project(":spring-beans")))
	testCompile(testFixtures(project(":spring-core")))
	testCompile(testFixtures(project(":spring-context")))
	testCompile("io.projectreactor:reactor-test")
	// H2 driver for integration tests; SPI TCK pinned explicitly since it is
	// not managed by the r2dbc-bom.
	testCompile("io.r2dbc:r2dbc-h2")
	testCompile("io.r2dbc:r2dbc-spi-test:0.8.1.RELEASE")
}

View File

@ -0,0 +1,67 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc;
import io.r2dbc.spi.R2dbcException;
import org.springframework.dao.InvalidDataAccessResourceUsageException;
/**
 * Exception thrown when the SQL issued to the database is invalid. Instances
 * of this exception always carry a {@link io.r2dbc.spi.R2dbcException} as
 * their root cause.
 *
 * <p>More specific subclasses (no such table, no such column, etc.) could be
 * introduced later; a custom R2dbcExceptionTranslator may create them without
 * breaking code that catches this class.
 *
 * @author Mark Paluch
 * @since 5.3
 */
@SuppressWarnings("serial")
public class BadSqlGrammarException extends InvalidDataAccessResourceUsageException {

	private final String sql;


	/**
	 * Create a new {@code BadSqlGrammarException}.
	 * @param task name of current task
	 * @param sql the offending SQL statement
	 * @param ex the root cause
	 */
	public BadSqlGrammarException(String task, String sql, R2dbcException ex) {
		super(task + "; bad SQL grammar [" + sql + "]", ex);
		this.sql = sql;
	}


	/**
	 * Return the SQL that caused the problem.
	 */
	public String getSql() {
		return this.sql;
	}

	/**
	 * Return the wrapped {@link R2dbcException}.
	 */
	public R2dbcException getR2dbcException() {
		return (R2dbcException) getCause();
	}

}

View File

@ -0,0 +1,66 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc;
import io.r2dbc.spi.R2dbcException;
import org.springframework.dao.UncategorizedDataAccessException;
import org.springframework.lang.Nullable;
/**
 * Exception thrown when we can't classify a {@link R2dbcException} into
 * one of our generic data access exceptions.
 *
 * @author Mark Paluch
 * @since 5.3
 */
@SuppressWarnings("serial")
public class UncategorizedR2dbcException extends UncategorizedDataAccessException {

	/** SQL that led to the problem. */
	@Nullable
	private final String sql;


	/**
	 * Constructor for {@code UncategorizedR2dbcException}.
	 * @param msg the detail message
	 * @param sql the offending SQL statement
	 * @param ex the exception thrown by underlying data access API
	 */
	public UncategorizedR2dbcException(String msg, @Nullable String sql, R2dbcException ex) {
		super(msg, ex);
		this.sql = sql;
	}


	/**
	 * Return the wrapped {@link R2dbcException}.
	 */
	public R2dbcException getR2dbcException() {
		return (R2dbcException) getCause();
	}

	/**
	 * Return the SQL that led to the problem (if known).
	 */
	@Nullable
	public String getSql() {
		return this.sql;
	}

}

View File

@ -0,0 +1,442 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection;
import io.r2dbc.spi.Connection;
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.R2dbcBadGrammarException;
import io.r2dbc.spi.R2dbcDataIntegrityViolationException;
import io.r2dbc.spi.R2dbcException;
import io.r2dbc.spi.R2dbcNonTransientException;
import io.r2dbc.spi.R2dbcNonTransientResourceException;
import io.r2dbc.spi.R2dbcPermissionDeniedException;
import io.r2dbc.spi.R2dbcRollbackException;
import io.r2dbc.spi.R2dbcTimeoutException;
import io.r2dbc.spi.R2dbcTransientException;
import io.r2dbc.spi.R2dbcTransientResourceException;
import io.r2dbc.spi.Wrapped;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import reactor.core.publisher.Mono;
import org.springframework.core.Ordered;
import org.springframework.dao.ConcurrencyFailureException;
import org.springframework.dao.DataAccessException;
import org.springframework.dao.DataAccessResourceFailureException;
import org.springframework.dao.DataIntegrityViolationException;
import org.springframework.dao.PermissionDeniedDataAccessException;
import org.springframework.dao.QueryTimeoutException;
import org.springframework.dao.TransientDataAccessResourceException;
import org.springframework.lang.Nullable;
import org.springframework.r2dbc.BadSqlGrammarException;
import org.springframework.r2dbc.UncategorizedR2dbcException;
import org.springframework.transaction.NoTransactionException;
import org.springframework.transaction.reactive.TransactionSynchronization;
import org.springframework.transaction.reactive.TransactionSynchronizationManager;
import org.springframework.util.Assert;
/**
 * Helper class that provides static methods for obtaining R2DBC Connections from
 * a {@link ConnectionFactory}.
 *
 * <p>Used internally by Spring's {@code DatabaseClient}, Spring's R2DBC operation
 * objects. Can also be used directly in application code.
 *
 * @author Mark Paluch
 * @author Christoph Strobl
 * @since 5.3
 * @see R2dbcTransactionManager
 * @see org.springframework.transaction.reactive.TransactionSynchronizationManager
 */
public abstract class ConnectionFactoryUtils {

	/**
	 * Order value for ReactiveTransactionSynchronization objects that clean up R2DBC Connections.
	 */
	public static final int CONNECTION_SYNCHRONIZATION_ORDER = 1000;

	private static final Log logger = LogFactory.getLog(ConnectionFactoryUtils.class);


	// Static utility holder: not meant to be instantiated.
	private ConnectionFactoryUtils() {}


	/**
	 * Obtain a {@link Connection} from the given {@link ConnectionFactory}.
	 * Translates exceptions into the Spring hierarchy of unchecked generic
	 * data access exceptions, simplifying calling code and making any
	 * exception that is thrown more meaningful.
	 *
	 * <p>Is aware of a corresponding Connection bound to the current
	 * {@link TransactionSynchronizationManager}. Will bind a Connection to the
	 * {@link TransactionSynchronizationManager} if transaction synchronization is active.
	 * @param connectionFactory the {@link ConnectionFactory} to obtain
	 * {@link Connection Connections} from
	 * @return a R2DBC Connection from the given {@link ConnectionFactory}
	 * @throws DataAccessResourceFailureException if the attempt to get a
	 * {@link Connection} failed
	 * @see #releaseConnection
	 */
	public static Mono<Connection> getConnection(ConnectionFactory connectionFactory) {
		// Any upstream error is wrapped; doGetConnection preserves the original.
		return doGetConnection(connectionFactory)
				.onErrorMap(e -> new DataAccessResourceFailureException("Failed to obtain R2DBC Connection", e));
	}

	/**
	 * Actually obtain a R2DBC Connection from the given {@link ConnectionFactory}.
	 * Same as {@link #getConnection}, but preserving the original exceptions.
	 *
	 * <p>Is aware of a corresponding Connection bound to the current
	 * {@link TransactionSynchronizationManager}. Will bind a Connection to the
	 * {@link TransactionSynchronizationManager} if transaction synchronization is active
	 * @param connectionFactory the {@link ConnectionFactory} to obtain Connections from
	 * @return a R2DBC {@link Connection} from the given {@link ConnectionFactory}.
	 */
	public static Mono<Connection> doGetConnection(ConnectionFactory connectionFactory) {
		Assert.notNull(connectionFactory, "ConnectionFactory must not be null");
		return TransactionSynchronizationManager.forCurrentTransaction().flatMap(synchronizationManager -> {
			ConnectionHolder conHolder = (ConnectionHolder) synchronizationManager.getResource(connectionFactory);
			if (conHolder != null && (conHolder.hasConnection() || conHolder.isSynchronizedWithTransaction())) {
				conHolder.requested();
				if (!conHolder.hasConnection()) {
					if (logger.isDebugEnabled()) {
						logger.debug("Fetching resumed R2DBC Connection from ConnectionFactory");
					}
					// Holder exists but was emptied (e.g. on suspend): repopulate it.
					return fetchConnection(connectionFactory).doOnNext(conHolder::setConnection);
				}
				return Mono.just(conHolder.getConnection());
			}
			// Else we either got no holder or an empty thread-bound holder here.
			if (logger.isDebugEnabled()) {
				logger.debug("Fetching R2DBC Connection from ConnectionFactory");
			}
			Mono<Connection> con = fetchConnection(connectionFactory);
			if (synchronizationManager.isSynchronizationActive()) {
				return con.flatMap(connection -> {
					return Mono.just(connection).doOnNext(conn -> {
						// Use same Connection for further R2DBC actions within the transaction.
						// Thread-bound object will get removed by synchronization at transaction completion.
						ConnectionHolder holderToUse = conHolder;
						if (holderToUse == null) {
							holderToUse = new ConnectionHolder(conn);
						}
						else {
							holderToUse.setConnection(conn);
						}
						holderToUse.requested();
						synchronizationManager
								.registerSynchronization(new ConnectionSynchronization(holderToUse, connectionFactory));
						holderToUse.setSynchronizedWithTransaction(true);
						if (holderToUse != conHolder) {
							synchronizationManager.bindResource(connectionFactory, holderToUse);
						}
					}) // Unexpected exception from external delegation call -> close Connection and rethrow.
					.onErrorResume(e -> releaseConnection(connection, connectionFactory).then(Mono.error(e)));
				});
			}
			return con;
		// No active transaction context: fall back to a plain, unmanaged Connection.
		}).onErrorResume(NoTransactionException.class, e -> Mono.from(connectionFactory.create()));
	}

	/**
	 * Actually fetch a {@link Connection} from the given {@link ConnectionFactory}.
	 * @param connectionFactory the {@link ConnectionFactory} to obtain
	 * {@link Connection}s from
	 * @return a R2DBC {@link Connection} from the given {@link ConnectionFactory}
	 * (never {@code null}).
	 * @throws IllegalStateException if the {@link ConnectionFactory} returned a {@code null} value.
	 * @see ConnectionFactory#create()
	 */
	private static Mono<Connection> fetchConnection(ConnectionFactory connectionFactory) {
		return Mono.from(connectionFactory.create());
	}

	/**
	 * Close the given {@link Connection}, obtained from the given {@link ConnectionFactory}, if
	 * it is not managed externally (that is, not bound to the subscription).
	 * @param con the {@link Connection} to close if necessary
	 * @param connectionFactory the {@link ConnectionFactory} that the Connection was obtained from
	 * @see #getConnection
	 */
	public static Mono<Void> releaseConnection(Connection con, ConnectionFactory connectionFactory) {
		return doReleaseConnection(con, connectionFactory)
				.onErrorMap(e -> new DataAccessResourceFailureException("Failed to close R2DBC Connection", e));
	}

	/**
	 * Actually close the given {@link Connection}, obtained from the given
	 * {@link ConnectionFactory}. Same as {@link #releaseConnection},
	 * but preserving the original exception.
	 * @param connection the {@link Connection} to close if necessary
	 * @param connectionFactory the {@link ConnectionFactory} that the Connection was obtained from
	 * @see #doGetConnection
	 */
	public static Mono<Void> doReleaseConnection(Connection connection,
			ConnectionFactory connectionFactory) {
		return TransactionSynchronizationManager.forCurrentTransaction()
				.flatMap(synchronizationManager -> {
					ConnectionHolder conHolder = (ConnectionHolder) synchronizationManager.getResource(connectionFactory);
					if (conHolder != null && connectionEquals(conHolder, connection)) {
						// It's the transactional Connection: Don't close it.
						conHolder.released();
					}
					// NOTE(review): close() is invoked even on the released transactional
					// Connection path here — presumably the driver/holder tolerates this;
					// confirm against R2dbcTransactionManager's cleanup.
					return Mono.from(connection.close());
				}).onErrorResume(NoTransactionException.class, e -> Mono.from(connection.close()));
	}

	/**
	 * Obtain the {@link ConnectionFactory} from the current {@link TransactionSynchronizationManager}.
	 * Emits the given factory only when an active, synchronized transactional
	 * Connection is bound for it; otherwise completes empty.
	 * @param connectionFactory the {@link ConnectionFactory} that the Connection was obtained from
	 * @see TransactionSynchronizationManager
	 */
	public static Mono<ConnectionFactory> currentConnectionFactory(ConnectionFactory connectionFactory) {
		return TransactionSynchronizationManager.forCurrentTransaction()
				.filter(TransactionSynchronizationManager::isSynchronizationActive)
				.filter(synchronizationManager -> {
					ConnectionHolder conHolder = (ConnectionHolder) synchronizationManager.getResource(connectionFactory);
					return conHolder != null && (conHolder.hasConnection() || conHolder.isSynchronizedWithTransaction());
				}).map(synchronizationManager -> connectionFactory);
	}

	/**
	 * Translate the given {@link R2dbcException} into a generic {@link DataAccessException}.
	 * <p>The returned DataAccessException is supposed to contain the original
	 * {@link R2dbcException} as root cause. However, client code may not generally
	 * rely on this due to DataAccessExceptions possibly being caused by other resource
	 * APIs as well. That said, a {@code getRootCause() instanceof R2dbcException}
	 * check (and subsequent cast) is considered reliable when expecting R2DBC-based
	 * access to have happened.
	 * @param task readable text describing the task being attempted
	 * @param sql the SQL query or update that caused the problem (if known)
	 * @param ex the offending {@link R2dbcException}
	 * @return the corresponding DataAccessException instance
	 */
	public static DataAccessException convertR2dbcException(String task, @Nullable String sql, R2dbcException ex) {
		// Map the R2DBC SPI exception hierarchy onto Spring's DataAccessException
		// hierarchy; more specific subtypes are checked before their parents.
		if (ex instanceof R2dbcTransientException) {
			if (ex instanceof R2dbcTransientResourceException) {
				return new TransientDataAccessResourceException(buildMessage(task, sql, ex), ex);
			}
			if (ex instanceof R2dbcRollbackException) {
				return new ConcurrencyFailureException(buildMessage(task, sql, ex), ex);
			}
			if (ex instanceof R2dbcTimeoutException) {
				return new QueryTimeoutException(buildMessage(task, sql, ex), ex);
			}
		}
		if (ex instanceof R2dbcNonTransientException) {
			if (ex instanceof R2dbcNonTransientResourceException) {
				return new DataAccessResourceFailureException(buildMessage(task, sql, ex), ex);
			}
			if (ex instanceof R2dbcDataIntegrityViolationException) {
				return new DataIntegrityViolationException(buildMessage(task, sql, ex), ex);
			}
			if (ex instanceof R2dbcPermissionDeniedException) {
				return new PermissionDeniedDataAccessException(buildMessage(task, sql, ex), ex);
			}
			if (ex instanceof R2dbcBadGrammarException) {
				return new BadSqlGrammarException(task, (sql != null ? sql : ""), ex);
			}
		}
		// Anything else (including plain R2dbcException) is uncategorized.
		return new UncategorizedR2dbcException(buildMessage(task, sql, ex), sql, ex);
	}

	/**
	 * Build a message {@code String} for the given {@link R2dbcException}.
	 * <p>To be called by translator subclasses when creating an instance of a generic
	 * {@link org.springframework.dao.DataAccessException} class.
	 * @param task readable text describing the task being attempted
	 * @param sql the SQL statement that caused the problem
	 * @param ex the offending {@code R2dbcException}
	 * @return the message {@code String} to use
	 */
	private static String buildMessage(String task, @Nullable String sql, R2dbcException ex) {
		return task + "; " + (sql != null ? ("SQL [" + sql + "]; ") : "") + ex.getMessage();
	}

	/**
	 * Determine whether the given two {@link Connection}s are equal, asking the target
	 * {@link Connection} in case of a proxy. Used to detect equality even if the user
	 * passed in a raw target Connection while the held one is a proxy.
	 * @param conHolder the {@link ConnectionHolder} for the held {@link Connection} (potentially a proxy)
	 * @param passedInCon the {@link Connection} passed-in by the user (potentially
	 * a target {@link Connection} without proxy).
	 * @return whether the given Connections are equal
	 * @see #getTargetConnection
	 */
	private static boolean connectionEquals(ConnectionHolder conHolder, Connection passedInCon) {
		if (!conHolder.hasConnection()) {
			return false;
		}
		Connection heldCon = conHolder.getConnection();
		// Explicitly check for identity too: for Connection handles that do not implement
		// "equals" properly.
		return (heldCon == passedInCon || heldCon.equals(passedInCon) || getTargetConnection(heldCon).equals(passedInCon));
	}

	/**
	 * Return the innermost target {@link Connection} of the given {@link Connection}.
	 * If the given {@link Connection} is wrapped, it will be unwrapped until a
	 * plain {@link Connection} is found. Otherwise, the passed-in Connection
	 * will be returned as-is.
	 * @param con the {@link Connection} wrapper to unwrap
	 * @return the innermost target Connection, or the passed-in one if not wrapped
	 * @see Wrapped#unwrap()
	 */
	@SuppressWarnings("unchecked")
	public static Connection getTargetConnection(Connection con) {
		Connection conToUse = con;
		while (conToUse instanceof Wrapped<?>) {
			conToUse = ((Wrapped<Connection>) conToUse).unwrap();
		}
		return conToUse;
	}

	/**
	 * Determine the connection synchronization order to use for the given {@link ConnectionFactory}.
	 * Decreased for every level of nesting that a {@link ConnectionFactory} has,
	 * checked through the level of {@link DelegatingConnectionFactory} nesting.
	 * @param connectionFactory the {@link ConnectionFactory} to check
	 * @return the connection synchronization order to use
	 * @see #CONNECTION_SYNCHRONIZATION_ORDER
	 */
	private static int getConnectionSynchronizationOrder(ConnectionFactory connectionFactory) {
		int order = CONNECTION_SYNCHRONIZATION_ORDER;
		ConnectionFactory current = connectionFactory;
		while (current instanceof DelegatingConnectionFactory) {
			// Inner factories synchronize earlier (lower order value) than wrappers.
			order--;
			current = ((DelegatingConnectionFactory) current).getTargetConnectionFactory();
		}
		return order;
	}


	/**
	 * Callback for resource cleanup at the end of a non-native R2DBC transaction.
	 */
	private static class ConnectionSynchronization implements TransactionSynchronization, Ordered {

		// Holder for the Connection to release when the transaction completes.
		private final ConnectionHolder connectionHolder;

		// Factory the holder is bound to in the TransactionSynchronizationManager.
		private final ConnectionFactory connectionFactory;

		private final int order;

		// False once this synchronization has unbound/cleaned up the holder.
		private boolean holderActive = true;

		ConnectionSynchronization(ConnectionHolder connectionHolder, ConnectionFactory connectionFactory) {
			this.connectionHolder = connectionHolder;
			this.connectionFactory = connectionFactory;
			this.order = getConnectionSynchronizationOrder(connectionFactory);
		}

		@Override
		public int getOrder() {
			return this.order;
		}

		@Override
		public Mono<Void> suspend() {
			if (this.holderActive) {
				return TransactionSynchronizationManager.forCurrentTransaction()
						.flatMap(synchronizationManager -> {
							synchronizationManager.unbindResource(this.connectionFactory);
							if (this.connectionHolder.hasConnection() && !this.connectionHolder.isOpen()) {
								// Release Connection on suspend if the application doesn't keep
								// a handle to it anymore. We will fetch a fresh Connection if the
								// application accesses the ConnectionHolder again after resume,
								// assuming that it will participate in the same transaction.
								return releaseConnection(this.connectionHolder.getConnection(), this.connectionFactory)
										.doOnTerminate(() -> this.connectionHolder.setConnection(null));
							}
							return Mono.empty();
						});
			}
			return Mono.empty();
		}

		@Override
		public Mono<Void> resume() {
			if (this.holderActive) {
				// Re-bind the (possibly emptied) holder so the resumed transaction
				// finds it again.
				return TransactionSynchronizationManager.forCurrentTransaction()
						.doOnNext(synchronizationManager -> synchronizationManager.bindResource(this.connectionFactory, this.connectionHolder))
						.then();
			}
			return Mono.empty();
		}

		@Override
		public Mono<Void> beforeCompletion() {
			// Release Connection early if the holder is not open anymore
			// (that is, not used by another resource
			// that has its own cleanup via transaction synchronization),
			// to avoid issues with strict transaction implementations that expect
			// the close call before transaction completion.
			if (!this.connectionHolder.isOpen()) {
				return TransactionSynchronizationManager.forCurrentTransaction()
						.flatMap(synchronizationManager -> {
							synchronizationManager.unbindResource(this.connectionFactory);
							this.holderActive = false;
							if (this.connectionHolder.hasConnection()) {
								return releaseConnection(this.connectionHolder.getConnection(), this.connectionFactory);
							}
							return Mono.empty();
						});
			}
			return Mono.empty();
		}

		@Override
		public Mono<Void> afterCompletion(int status) {
			// If we haven't closed the Connection in beforeCompletion,
			// close it now.
			if (this.holderActive) {
				// The bound ConnectionHolder might not be available anymore,
				// since afterCompletion might get called from a different thread.
				return TransactionSynchronizationManager.forCurrentTransaction()
						.flatMap(synchronizationManager -> {
							synchronizationManager.unbindResourceIfPossible(this.connectionFactory);
							this.holderActive = false;
							if (this.connectionHolder.hasConnection()) {
								return releaseConnection(this.connectionHolder.getConnection(), this.connectionFactory)
										// Reset the ConnectionHolder: It might remain bound to the context.
										.doOnTerminate(() -> this.connectionHolder.setConnection(null));
							}
							return Mono.empty();
						});
			}
			// Already cleaned up in beforeCompletion: just reset the holder state.
			this.connectionHolder.reset();
			return Mono.empty();
		}
	}

}

View File

@ -0,0 +1,136 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection;
import io.r2dbc.spi.Connection;
import io.r2dbc.spi.ConnectionFactory;
import org.springframework.lang.Nullable;
import org.springframework.transaction.support.ResourceHolderSupport;
import org.springframework.util.Assert;
/**
 * Resource holder wrapping a R2DBC {@link Connection}.
 * {@link R2dbcTransactionManager} binds instances of this class to the subscription,
 * for a specific {@link ConnectionFactory}.
 *
 * <p>Rollback-only support for nested R2DBC transactions and reference-count
 * functionality are inherited from the base class.
 *
 * <p>Note: This is an SPI class, not intended to be used by applications.
 *
 * @author Mark Paluch
 * @author Christoph Strobl
 * @since 5.3
 * @see R2dbcTransactionManager
 * @see ConnectionFactoryUtils
 */
public class ConnectionHolder extends ResourceHolderSupport {

	@Nullable
	private Connection connection;

	private boolean transactionActive;


	/**
	 * Create a new ConnectionHolder for the given R2DBC {@link Connection},
	 * assuming that there is no ongoing transaction.
	 * @param connection the R2DBC {@link Connection} to hold
	 * @see #ConnectionHolder(Connection, boolean)
	 */
	public ConnectionHolder(Connection connection) {
		this(connection, false);
	}

	/**
	 * Create a new ConnectionHolder for the given R2DBC {@link Connection}.
	 * @param connection the R2DBC {@link Connection} to hold
	 * @param transactionActive whether the given {@link Connection} is involved
	 * in an ongoing transaction
	 */
	public ConnectionHolder(Connection connection, boolean transactionActive) {
		this.connection = connection;
		this.transactionActive = transactionActive;
	}


	/**
	 * Return whether this holder currently has a {@link Connection}.
	 */
	protected boolean hasConnection() {
		return (this.connection != null);
	}

	/**
	 * Mark this holder as representing (or not representing) an active,
	 * R2DBC-managed transaction.
	 * @see R2dbcTransactionManager
	 */
	protected void setTransactionActive(boolean transactionActive) {
		this.transactionActive = transactionActive;
	}

	/**
	 * Return whether this holder represents an active, R2DBC-managed transaction.
	 */
	protected boolean isTransactionActive() {
		return this.transactionActive;
	}

	/**
	 * Override the existing Connection with the given {@link Connection}.
	 * <p>Used for releasing the {@link Connection} on suspend
	 * (with a {@code null} argument) and setting a fresh {@link Connection} on resume.
	 */
	protected void setConnection(@Nullable Connection connection) {
		this.connection = connection;
	}

	/**
	 * Return the current {@link Connection} held by this {@link ConnectionHolder}.
	 * <p>This remains the same {@link Connection} until {@code released} gets called
	 * on the {@link ConnectionHolder}, which resets the held {@link Connection},
	 * fetching a new {@link Connection} on demand.
	 * @see #released()
	 */
	public Connection getConnection() {
		Assert.notNull(this.connection, "Active Connection is required");
		return this.connection;
	}

	/**
	 * Releases the current {@link Connection}: drops the reference once no
	 * consumer holds the Connection open anymore.
	 */
	@Override
	public void released() {
		super.released();
		if (!isOpen() && this.connection != null) {
			this.connection = null;
		}
	}

	@Override
	public void clear() {
		super.clear();
		this.transactionActive = false;
	}

}

View File

@ -0,0 +1,76 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection;
import io.r2dbc.spi.Connection;
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.ConnectionFactoryMetadata;
import io.r2dbc.spi.Wrapped;
import reactor.core.publisher.Mono;
import org.springframework.util.Assert;
/**
 * R2DBC {@link ConnectionFactory} implementation that forwards every call
 * to a given target {@link ConnectionFactory}.
 *
 * <p>Designed for subclassing: override only the methods (such as
 * {@link #create()}) whose behavior should differ from plain delegation
 * to the target {@link ConnectionFactory}.
 *
 * @author Mark Paluch
 * @since 5.3
 * @see #create
 */
public class DelegatingConnectionFactory implements ConnectionFactory, Wrapped<ConnectionFactory> {

	private final ConnectionFactory targetConnectionFactory;


	public DelegatingConnectionFactory(ConnectionFactory targetConnectionFactory) {
		Assert.notNull(targetConnectionFactory, "ConnectionFactory must not be null");
		this.targetConnectionFactory = targetConnectionFactory;
	}


	public ConnectionFactory getTargetConnectionFactory() {
		return this.targetConnectionFactory;
	}

	/**
	 * Obtain the target {@link ConnectionFactory} for actual use (never {@code null}).
	 */
	protected ConnectionFactory obtainTargetConnectionFactory() {
		return getTargetConnectionFactory();
	}


	@Override
	public Mono<? extends Connection> create() {
		ConnectionFactory target = this.targetConnectionFactory;
		return Mono.from(target.create());
	}

	@Override
	public ConnectionFactoryMetadata getMetadata() {
		return obtainTargetConnectionFactory().getMetadata();
	}

	@Override
	public ConnectionFactory unwrap() {
		return obtainTargetConnectionFactory();
	}

}

View File

@ -0,0 +1,538 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection;
import java.time.Duration;
import io.r2dbc.spi.Connection;
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.IsolationLevel;
import io.r2dbc.spi.R2dbcException;
import io.r2dbc.spi.Result;
import reactor.core.publisher.Mono;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.lang.Nullable;
import org.springframework.transaction.CannotCreateTransactionException;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.TransactionException;
import org.springframework.transaction.reactive.AbstractReactiveTransactionManager;
import org.springframework.transaction.reactive.GenericReactiveTransaction;
import org.springframework.transaction.reactive.TransactionSynchronizationManager;
import org.springframework.util.Assert;
/**
* {@link org.springframework.transaction.ReactiveTransactionManager}
* implementation for a single R2DBC {@link ConnectionFactory}. This class is
* capable of working in any environment with any R2DBC driver, as long as the
* setup uses a {@link ConnectionFactory} as its {@link Connection} factory
* mechanism. Binds a R2DBC {@link Connection} from the specified
* {@link ConnectionFactory} to the current subscriber context, potentially
* allowing for one context-bound {@link Connection} per {@link ConnectionFactory}.
*
* <p><b>Note: The {@link ConnectionFactory} that this transaction manager
* operates on needs to return independent {@link Connection}s.</b>
* The {@link Connection}s may come from a pool (the typical case), but the
 * {@link ConnectionFactory} must not return scoped {@link Connection}s
* or the like. This transaction manager will associate {@link Connection}
* with context-bound transactions itself, according to the specified propagation
* behavior. It assumes that a separate, independent {@link Connection} can
* be obtained even during an ongoing transaction.
*
* <p>Application code is required to retrieve the R2DBC Connection via
* {@link ConnectionFactoryUtils#getConnection(ConnectionFactory)}
* instead of a standard R2DBC-style {@link ConnectionFactory#create()} call.
* Spring classes such as {@code DatabaseClient} use this strategy implicitly.
* If not used in combination with this transaction manager, the
* {@link ConnectionFactoryUtils} lookup strategy behaves exactly like the
* native {@link ConnectionFactory} lookup; it can thus be used in a portable fashion.
*
* <p>Alternatively, you can allow application code to work with the standard
* R2DBC lookup pattern {@link ConnectionFactory#create()}, for example for code
* that is not aware of Spring at all. In that case, define a
* {@link TransactionAwareConnectionFactoryProxy} for your target {@link ConnectionFactory},
* and pass that proxy {@link ConnectionFactory} to your DAOs, which will automatically
* participate in Spring-managed transactions when accessing it.
*
* <p>This transaction manager triggers flush callbacks on registered transaction
* synchronizations (if synchronization is generally active), assuming resources
* operating on the underlying R2DBC {@link Connection}.
*
* @author Mark Paluch
* @since 5.3
* @see ConnectionFactoryUtils#getConnection(ConnectionFactory)
* @see ConnectionFactoryUtils#releaseConnection
* @see TransactionAwareConnectionFactoryProxy
*/
@SuppressWarnings("serial")
public class R2dbcTransactionManager extends AbstractReactiveTransactionManager implements InitializingBean {

	@Nullable
	private ConnectionFactory connectionFactory;

	private boolean enforceReadOnly = false;


	/**
	 * Create a new {@link R2dbcTransactionManager} instance.
	 * A ConnectionFactory has to be set to be able to use it.
	 * @see #setConnectionFactory
	 */
	public R2dbcTransactionManager() {}

	/**
	 * Create a new {@link R2dbcTransactionManager} instance.
	 * @param connectionFactory the R2DBC ConnectionFactory to manage transactions for
	 */
	public R2dbcTransactionManager(ConnectionFactory connectionFactory) {
		this();
		setConnectionFactory(connectionFactory);
		afterPropertiesSet();
	}

	/**
	 * Set the R2DBC {@link ConnectionFactory} that this instance should manage transactions for.
	 * <p>This will typically be a locally defined {@link ConnectionFactory}, for example
	 * a connection pool.
	 * <p>The {@link ConnectionFactory} specified here should be the target {@link ConnectionFactory}
	 * to manage transactions for, not a TransactionAwareConnectionFactoryProxy. Only data access
	 * code may work with TransactionAwareConnectionFactoryProxy, while the transaction manager
	 * needs to work on the underlying target {@link ConnectionFactory}. If there's nevertheless
	 * a TransactionAwareConnectionFactoryProxy passed in, it will be unwrapped to extract its
	 * target {@link ConnectionFactory}.
	 * <p><b>The {@link ConnectionFactory} passed in here needs to return independent
	 * {@link Connection}s.</b> The {@link Connection}s may come from a pool (the typical case),
	 * but the {@link ConnectionFactory} must not return scoped {@link Connection}s or the like.
	 * @see TransactionAwareConnectionFactoryProxy
	 */
	public void setConnectionFactory(@Nullable ConnectionFactory connectionFactory) {
		this.connectionFactory = connectionFactory;
	}

	/**
	 * Return the R2DBC {@link ConnectionFactory} that this instance manages transactions for.
	 */
	@Nullable
	public ConnectionFactory getConnectionFactory() {
		return this.connectionFactory;
	}

	/**
	 * Obtain the {@link ConnectionFactory} for actual use.
	 * @return the {@link ConnectionFactory} (never {@code null})
	 * @throws IllegalStateException in case of no ConnectionFactory set
	 */
	protected ConnectionFactory obtainConnectionFactory() {
		ConnectionFactory connectionFactory = getConnectionFactory();
		Assert.state(connectionFactory != null, "No ConnectionFactory set");
		return connectionFactory;
	}

	/**
	 * Specify whether to enforce the read-only nature of a transaction (as indicated by
	 * {@link TransactionDefinition#isReadOnly()}) through an explicit statement on the
	 * transactional connection: "SET TRANSACTION READ ONLY" as understood by Oracle,
	 * MySQL and Postgres.
	 * <p>The exact treatment, including any SQL statement executed on the connection,
	 * can be customized through {@link #prepareTransactionalConnection}.
	 * @see #prepareTransactionalConnection
	 */
	public void setEnforceReadOnly(boolean enforceReadOnly) {
		this.enforceReadOnly = enforceReadOnly;
	}

	/**
	 * Return whether to enforce the read-only nature of a transaction through an
	 * explicit statement on the transactional connection.
	 * @see #setEnforceReadOnly
	 */
	public boolean isEnforceReadOnly() {
		return this.enforceReadOnly;
	}

	@Override
	public void afterPropertiesSet() {
		if (getConnectionFactory() == null) {
			throw new IllegalArgumentException("Property 'connectionFactory' is required");
		}
	}

	@Override
	protected Object doGetTransaction(TransactionSynchronizationManager synchronizationManager)
			throws TransactionException {
		ConnectionFactoryTransactionObject txObject = new ConnectionFactoryTransactionObject();
		ConnectionHolder conHolder = (ConnectionHolder) synchronizationManager.getResource(obtainConnectionFactory());
		txObject.setConnectionHolder(conHolder, false);
		return txObject;
	}

	@Override
	protected boolean isExistingTransaction(Object transaction) {
		ConnectionFactoryTransactionObject txObject = (ConnectionFactoryTransactionObject) transaction;
		return (txObject.hasConnectionHolder() && txObject.getConnectionHolder().isTransactionActive());
	}

	@Override
	protected Mono<Void> doBegin(TransactionSynchronizationManager synchronizationManager, Object transaction,
			TransactionDefinition definition) throws TransactionException {
		ConnectionFactoryTransactionObject txObject = (ConnectionFactoryTransactionObject) transaction;
		return Mono.defer(() -> {
			Mono<Connection> connectionMono;
			if (!txObject.hasConnectionHolder() || txObject.getConnectionHolder().isSynchronizedWithTransaction()) {
				Mono<Connection> newCon = Mono.from(obtainConnectionFactory().create());
				connectionMono = newCon.doOnNext(connection -> {
					if (logger.isDebugEnabled()) {
						// Log the materialized Connection, not the Mono that emitted it.
						logger.debug("Acquired Connection [" + connection + "] for R2DBC transaction");
					}
					txObject.setConnectionHolder(new ConnectionHolder(connection), true);
				});
			}
			else {
				txObject.getConnectionHolder().setSynchronizedWithTransaction(true);
				connectionMono = Mono.just(txObject.getConnectionHolder().getConnection());
			}
			return connectionMono.flatMap(con -> {
				return prepareTransactionalConnection(con, definition, transaction).then(Mono.from(con.beginTransaction()))
						.doOnSuccess(v -> {
							txObject.getConnectionHolder().setTransactionActive(true);
							Duration timeout = determineTimeout(definition);
							if (!timeout.isNegative() && !timeout.isZero()) {
								txObject.getConnectionHolder().setTimeoutInMillis(timeout.toMillis());
							}
							// Bind the connection holder to the thread.
							if (txObject.isNewConnectionHolder()) {
								synchronizationManager.bindResource(obtainConnectionFactory(), txObject.getConnectionHolder());
							}
						}).thenReturn(con).onErrorResume(e -> {
							// Release a freshly acquired Connection on begin failure before propagating.
							if (txObject.isNewConnectionHolder()) {
								return ConnectionFactoryUtils.releaseConnection(con, obtainConnectionFactory())
										.doOnTerminate(() -> txObject.setConnectionHolder(null, false))
										.then(Mono.error(e));
							}
							return Mono.error(e);
						});
			}).onErrorResume(e -> {
				CannotCreateTransactionException ex = new CannotCreateTransactionException(
						"Could not open R2DBC Connection for transaction",
						e);
				return Mono.error(ex);
			});
		}).then();
	}

	/**
	 * Determine the actual timeout to use for the given definition. Will fall back
	 * to this manager's default timeout if the transaction definition doesn't
	 * specify a non-default value.
	 * @param definition the transaction definition
	 * @return the actual timeout to use
	 * @see org.springframework.transaction.TransactionDefinition#getTimeout()
	 */
	protected Duration determineTimeout(TransactionDefinition definition) {
		if (definition.getTimeout() != TransactionDefinition.TIMEOUT_DEFAULT) {
			return Duration.ofSeconds(definition.getTimeout());
		}
		return Duration.ZERO;
	}

	@Override
	protected Mono<Object> doSuspend(TransactionSynchronizationManager synchronizationManager, Object transaction)
			throws TransactionException {
		return Mono.defer(() -> {
			ConnectionFactoryTransactionObject txObject = (ConnectionFactoryTransactionObject) transaction;
			txObject.setConnectionHolder(null);
			return Mono.justOrEmpty(synchronizationManager.unbindResource(obtainConnectionFactory()));
		});
	}

	@Override
	protected Mono<Void> doResume(TransactionSynchronizationManager synchronizationManager, Object transaction,
			Object suspendedResources) throws TransactionException {
		return Mono.defer(() -> {
			ConnectionFactoryTransactionObject txObject = (ConnectionFactoryTransactionObject) transaction;
			txObject.setConnectionHolder(null);
			synchronizationManager.bindResource(obtainConnectionFactory(), suspendedResources);
			return Mono.empty();
		});
	}

	@Override
	protected Mono<Void> doCommit(TransactionSynchronizationManager synchronizationManager,
			GenericReactiveTransaction status) throws TransactionException {
		ConnectionFactoryTransactionObject txObject = (ConnectionFactoryTransactionObject) status.getTransaction();
		Connection connection = txObject.getConnectionHolder().getConnection();
		if (status.isDebug()) {
			logger.debug("Committing R2DBC transaction on Connection [" + connection + "]");
		}
		return Mono.from(connection.commitTransaction())
				.onErrorMap(R2dbcException.class, ex -> translateException("R2DBC commit", ex));
	}

	@Override
	protected Mono<Void> doRollback(TransactionSynchronizationManager synchronizationManager,
			GenericReactiveTransaction status) throws TransactionException {
		ConnectionFactoryTransactionObject txObject = (ConnectionFactoryTransactionObject) status.getTransaction();
		Connection connection = txObject.getConnectionHolder().getConnection();
		if (status.isDebug()) {
			logger.debug("Rolling back R2DBC transaction on Connection [" + connection + "]");
		}
		return Mono.from(connection.rollbackTransaction())
				.onErrorMap(R2dbcException.class, ex -> translateException("R2DBC rollback", ex));
	}

	@Override
	protected Mono<Void> doSetRollbackOnly(TransactionSynchronizationManager synchronizationManager,
			GenericReactiveTransaction status) throws TransactionException {
		return Mono.fromRunnable(() -> {
			ConnectionFactoryTransactionObject txObject = (ConnectionFactoryTransactionObject) status.getTransaction();
			if (status.isDebug()) {
				logger
						.debug("Setting R2DBC transaction [" + txObject.getConnectionHolder().getConnection() + "] rollback-only");
			}
			txObject.setRollbackOnly();
		});
	}

	@Override
	protected Mono<Void> doCleanupAfterCompletion(TransactionSynchronizationManager synchronizationManager,
			Object transaction) {
		return Mono.defer(() -> {
			ConnectionFactoryTransactionObject txObject = (ConnectionFactoryTransactionObject) transaction;
			// Remove the connection holder from the context, if exposed.
			if (txObject.isNewConnectionHolder()) {
				synchronizationManager.unbindResource(obtainConnectionFactory());
			}
			// Reset connection: restore auto-commit and isolation level if changed.
			Connection con = txObject.getConnectionHolder().getConnection();
			Mono<Void> afterCleanup = Mono.empty();
			if (txObject.isMustRestoreAutoCommit()) {
				afterCleanup = afterCleanup.then(Mono.from(con.setAutoCommit(true)));
			}
			if (txObject.getPreviousIsolationLevel() != null) {
				afterCleanup = afterCleanup
						.then(Mono.from(con.setTransactionIsolationLevel(txObject.getPreviousIsolationLevel())));
			}
			return afterCleanup.then(Mono.defer(() -> {
				try {
					if (txObject.isNewConnectionHolder()) {
						if (logger.isDebugEnabled()) {
							logger.debug("Releasing R2DBC Connection [" + con + "] after transaction");
						}
						return ConnectionFactoryUtils.releaseConnection(con, obtainConnectionFactory());
					}
				}
				finally {
					txObject.getConnectionHolder().clear();
				}
				return Mono.empty();
			}));
		});
	}

	/**
	 * Prepare the transactional {@link Connection} right after transaction begin.
	 * <p>The default implementation executes a "SET TRANSACTION READ ONLY" statement if the
	 * {@link #setEnforceReadOnly "enforceReadOnly"} flag is set to {@code true} and the
	 * transaction definition indicates a read-only transaction.
	 * <p>The "SET TRANSACTION READ ONLY" is understood by Oracle, MySQL and Postgres and may
	 * work with other databases as well. If you'd like to adapt this treatment, override this
	 * method accordingly.
	 * @param con the transactional R2DBC Connection
	 * @param definition the current transaction definition
	 * @param transaction the transaction object
	 * @see #setEnforceReadOnly
	 */
	protected Mono<Void> prepareTransactionalConnection(Connection con, TransactionDefinition definition,
			Object transaction) {
		ConnectionFactoryTransactionObject txObject = (ConnectionFactoryTransactionObject) transaction;
		Mono<Void> prepare = Mono.empty();
		if (isEnforceReadOnly() && definition.isReadOnly()) {
			prepare = Mono.from(con.createStatement("SET TRANSACTION READ ONLY").execute())
					.flatMapMany(Result::getRowsUpdated)
					.then();
		}
		// Apply specific isolation level, if any.
		IsolationLevel isolationLevelToUse = resolveIsolationLevel(definition.getIsolationLevel());
		if (isolationLevelToUse != null && definition.getIsolationLevel() != TransactionDefinition.ISOLATION_DEFAULT) {
			if (logger.isDebugEnabled()) {
				logger
						.debug("Changing isolation level of R2DBC Connection [" + con + "] to " + isolationLevelToUse.asSql());
			}
			IsolationLevel currentIsolation = con.getTransactionIsolationLevel();
			if (!currentIsolation.asSql().equalsIgnoreCase(isolationLevelToUse.asSql())) {
				txObject.setPreviousIsolationLevel(currentIsolation);
				prepare = prepare.then(Mono.from(con.setTransactionIsolationLevel(isolationLevelToUse)));
			}
		}
		// Switch to manual commit if necessary. This is very expensive in some R2DBC drivers,
		// so we don't want to do it unnecessarily (for example if we've explicitly
		// configured the connection pool to set it already).
		if (con.isAutoCommit()) {
			txObject.setMustRestoreAutoCommit(true);
			if (logger.isDebugEnabled()) {
				logger.debug("Switching R2DBC Connection [" + con + "] to manual commit");
			}
			prepare = prepare.then(Mono.from(con.setAutoCommit(false)));
		}
		return prepare;
	}

	/**
	 * Resolve the {@link TransactionDefinition#getIsolationLevel() isolation level constant}
	 * to a R2DBC {@link IsolationLevel}. If you'd like to extend isolation level translation
	 * for vendor-specific {@link IsolationLevel}s, override this method accordingly.
	 * @param isolationLevel the isolation level to translate.
	 * @return the resolved isolation level. Can be {@code null} if not resolvable or the
	 * isolation level should remain {@link TransactionDefinition#ISOLATION_DEFAULT default}.
	 * @see TransactionDefinition#getIsolationLevel()
	 */
	@Nullable
	protected IsolationLevel resolveIsolationLevel(int isolationLevel) {
		switch (isolationLevel) {
			case TransactionDefinition.ISOLATION_READ_COMMITTED:
				return IsolationLevel.READ_COMMITTED;
			case TransactionDefinition.ISOLATION_READ_UNCOMMITTED:
				return IsolationLevel.READ_UNCOMMITTED;
			case TransactionDefinition.ISOLATION_REPEATABLE_READ:
				return IsolationLevel.REPEATABLE_READ;
			case TransactionDefinition.ISOLATION_SERIALIZABLE:
				return IsolationLevel.SERIALIZABLE;
		}
		return null;
	}

	/**
	 * Translate the given R2DBC commit/rollback exception to a common Spring exception to
	 * propagate from the {@link #commit}/{@link #rollback} call.
	 * @param task the task description (commit or rollback).
	 * @param ex the SQLException thrown from commit/rollback.
	 * @return the translated exception to emit
	 */
	protected RuntimeException translateException(String task, R2dbcException ex) {
		return ConnectionFactoryUtils.convertR2dbcException(task, null, ex);
	}

	/**
	 * ConnectionFactory transaction object, representing a ConnectionHolder. Used as
	 * transaction object by R2dbcTransactionManager.
	 */
	private static class ConnectionFactoryTransactionObject {

		@Nullable
		private ConnectionHolder connectionHolder;

		@Nullable
		private IsolationLevel previousIsolationLevel;

		private boolean newConnectionHolder;

		private boolean mustRestoreAutoCommit;

		void setConnectionHolder(@Nullable ConnectionHolder connectionHolder, boolean newConnectionHolder) {
			setConnectionHolder(connectionHolder);
			this.newConnectionHolder = newConnectionHolder;
		}

		boolean isNewConnectionHolder() {
			return this.newConnectionHolder;
		}

		void setRollbackOnly() {
			getConnectionHolder().setRollbackOnly();
		}

		public void setConnectionHolder(@Nullable ConnectionHolder connectionHolder) {
			this.connectionHolder = connectionHolder;
		}

		public ConnectionHolder getConnectionHolder() {
			Assert.state(this.connectionHolder != null, "No ConnectionHolder available");
			return this.connectionHolder;
		}

		public boolean hasConnectionHolder() {
			return (this.connectionHolder != null);
		}

		public void setPreviousIsolationLevel(@Nullable IsolationLevel previousIsolationLevel) {
			this.previousIsolationLevel = previousIsolationLevel;
		}

		@Nullable
		public IsolationLevel getPreviousIsolationLevel() {
			return this.previousIsolationLevel;
		}

		public void setMustRestoreAutoCommit(boolean mustRestoreAutoCommit) {
			this.mustRestoreAutoCommit = mustRestoreAutoCommit;
		}

		public boolean isMustRestoreAutoCommit() {
			return this.mustRestoreAutoCommit;
		}
	}
}

View File

@ -0,0 +1,296 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.concurrent.atomic.AtomicReference;
import io.r2dbc.spi.Connection;
import io.r2dbc.spi.ConnectionFactories;
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.ConnectionFactoryMetadata;
import io.r2dbc.spi.Wrapped;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Mono;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
/**
* Implementation of {@link DelegatingConnectionFactory} that wraps a
* single R2DBC Connection which is not closed after use.
* Obviously, this is not multi-threading capable.
*
* <p>Note that at shutdown, someone should close the underlying
* Connection via the {@code close()} method. Client code will
* never call close on the Connection handle if it is
* SmartConnectionFactory-aware (e.g. uses
* {@link ConnectionFactoryUtils#releaseConnection(Connection, ConnectionFactory)}).
*
* <p>If client code will call {@link Connection#close()} in the
* assumption of a pooled Connection, like when using persistence tools,
* set "suppressClose" to "true". This will return a close-suppressing
* proxy instead of the physical Connection.
*
* <p>This is primarily intended for testing and pipelining usage of connections.
* For example, it enables easy testing outside an application server, for code
* that expects to work on a {@link ConnectionFactory}.
* Note that this implementation does not act as a connection pool-like utility.
* Connection pooling requires a {@link ConnectionFactory} implemented by e.g.
* {@code r2dbc-pool}.
*
* @author Mark Paluch
* @since 5.3
* @see #create()
* @see Connection#close()
* @see ConnectionFactoryUtils#releaseConnection(Connection, ConnectionFactory)
*/
public class SingleConnectionFactory extends DelegatingConnectionFactory
		implements DisposableBean {

	/** Whether to hand out a close-suppressing proxy instead of the physical Connection. */
	private boolean suppressClose;

	/** Auto-commit override to apply to the shared Connection; {@code null} means leave as-is. */
	private @Nullable Boolean autoCommit;

	/** Wrapped (physical) Connection, once materialized. */
	private final AtomicReference<Connection> target = new AtomicReference<>();

	/** Proxy Connection handed out to clients (close-suppressing if configured). */
	private @Nullable Connection connection;

	// Emits the single shared Connection; cached so that the underlying
	// factory is subscribed to at most once.
	private final Mono<? extends Connection> connectionEmitter;

	/**
	 * Constructor for bean-style configuration.
	 */
	public SingleConnectionFactory(ConnectionFactory targetConnectionFactory) {
		super(targetConnectionFactory);
		this.connectionEmitter = super.create().cache();
	}

	/**
	 * Create a new {@link SingleConnectionFactory} using a R2DBC connection URL.
	 * @param url the R2DBC URL to use for accessing {@link ConnectionFactory} discovery.
	 * @param suppressClose if the returned {@link Connection} should be a close-suppressing
	 * proxy or the physical {@link Connection}.
	 * @see ConnectionFactories#get(String)
	 */
	public SingleConnectionFactory(String url, boolean suppressClose) {
		super(ConnectionFactories.get(url));
		this.suppressClose = suppressClose;
		this.connectionEmitter = super.create().cache();
	}

	/**
	 * Create a new {@link SingleConnectionFactory} with a given {@link Connection} and
	 * {@link ConnectionFactoryMetadata}.
	 * @param target underlying target {@link Connection}.
	 * @param metadata {@link ConnectionFactory} metadata to be associated with this
	 * {@link ConnectionFactory}.
	 * @param suppressClose if the {@link Connection} should be wrapped with a {@link Connection}
	 * that suppresses {@code close()} calls (to allow for normal {@code close()} usage in
	 * applications that expect a pooled {@link Connection} but do not know our
	 * {@link SmartConnectionFactory} interface).
	 */
	public SingleConnectionFactory(Connection target, ConnectionFactoryMetadata metadata,
			boolean suppressClose) {
		// Anonymous factory that always emits the given Connection with the given metadata.
		super(new ConnectionFactory() {
			@Override
			public Publisher<? extends Connection> create() {
				return Mono.just(target);
			}
			@Override
			public ConnectionFactoryMetadata getMetadata() {
				return metadata;
			}
		});
		Assert.notNull(target, "Connection must not be null");
		Assert.notNull(metadata, "ConnectionFactoryMetadata must not be null");
		this.target.set(target);
		this.connectionEmitter = Mono.just(target);
		this.suppressClose = suppressClose;
		this.connection = (suppressClose ? getCloseSuppressingConnectionProxy(target) : target);
	}

	/**
	 * Set whether the returned {@link Connection} should be a close-suppressing proxy
	 * or the physical {@link Connection}.
	 */
	public void setSuppressClose(boolean suppressClose) {
		this.suppressClose = suppressClose;
	}

	/**
	 * Return whether the returned {@link Connection} will be a close-suppressing proxy
	 * or the physical {@link Connection}.
	 */
	protected boolean isSuppressClose() {
		return this.suppressClose;
	}

	/**
	 * Set whether the returned {@link Connection}'s "autoCommit" setting should be overridden.
	 */
	public void setAutoCommit(boolean autoCommit) {
		this.autoCommit = autoCommit;
	}

	/**
	 * Return whether the returned {@link Connection}'s "autoCommit" setting should be overridden.
	 * @return the "autoCommit" value, or {@code null} if none to be applied
	 */
	@Nullable
	protected Boolean getAutoCommitValue() {
		return this.autoCommit;
	}

	@Override
	public Mono<? extends Connection> create() {
		Connection connection = this.target.get();
		return this.connectionEmitter.map(connectionToUse -> {
			// First subscriber materializes the shared Connection (and its proxy);
			// subsequent subscribers reuse the cached proxy.
			if (connection == null) {
				this.target.compareAndSet(connection, connectionToUse);
				this.connection = (isSuppressClose() ? getCloseSuppressingConnectionProxy(connectionToUse) : connectionToUse);
			}
			return this.connection;
		}).flatMap(this::prepareConnection);
	}

	/**
	 * Close the underlying {@link Connection}. The provider of this {@link ConnectionFactory}
	 * needs to care for proper shutdown.
	 * <p>As this bean implements {@link DisposableBean}, a bean factory will automatically
	 * invoke this on destruction of its cached singletons.
	 */
	@Override
	public void destroy() {
		resetConnection().block();
	}

	/**
	 * Reset the underlying shared Connection, to be reinitialized on next access.
	 */
	public Mono<Void> resetConnection() {
		Connection connection = this.target.get();
		if (connection == null) {
			return Mono.empty();
		}
		return Mono.defer(() -> {
			// Only the subscriber that wins the CAS actually closes the Connection.
			if (this.target.compareAndSet(connection, null)) {
				this.connection = null;
				return Mono.from(connection.close());
			}
			return Mono.empty();
		});
	}

	/**
	 * Prepare the {@link Connection} before using it. Applies {@link #getAutoCommitValue()
	 * auto-commit} settings if configured.
	 * @param connection the requested {@link Connection}.
	 * @return the prepared {@link Connection}.
	 */
	protected Mono<Connection> prepareConnection(Connection connection) {
		Boolean autoCommit = getAutoCommitValue();
		if (autoCommit != null) {
			return Mono.from(connection.setAutoCommit(autoCommit)).thenReturn(connection);
		}
		return Mono.just(connection);
	}

	/**
	 * Wrap the given {@link Connection} with a proxy that delegates every method call to it
	 * but suppresses close calls.
	 * @param target the original {@link Connection} to wrap.
	 * @return the wrapped Connection.
	 */
	protected Connection getCloseSuppressingConnectionProxy(Connection target) {
		return (Connection) Proxy.newProxyInstance(SingleConnectionFactory.class.getClassLoader(),
				new Class<?>[] { Connection.class, Wrapped.class }, new CloseSuppressingInvocationHandler(target));
	}

	/**
	 * Invocation handler that suppresses close calls on R2DBC Connections.
	 * @see Connection#close()
	 */
	private static class CloseSuppressingInvocationHandler implements InvocationHandler {

		private final Connection target;

		CloseSuppressingInvocationHandler(Connection target) {
			this.target = target;
		}

		@Override
		@Nullable
		public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
			// Invocation on ConnectionProxy interface coming in...
			if (method.getName().equals("equals")) {
				// Only consider equal when proxies are identical.
				return proxy == args[0];
			}
			else if (method.getName().equals("hashCode")) {
				// Use hashCode of Connection proxy.
				return System.identityHashCode(proxy);
			}
			else if (method.getName().equals("unwrap")) {
				return this.target;
			}
			else if (method.getName().equals("close")) {
				// Handle close method: suppress, not valid.
				return Mono.empty();
			}
			// Invoke method on target Connection.
			try {
				return method.invoke(this.target, args);
			}
			catch (InvocationTargetException ex) {
				throw ex.getTargetException();
			}
		}
	}
}

View File

@ -0,0 +1,198 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import io.r2dbc.spi.Connection;
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.Wrapped;
import reactor.core.publisher.Mono;
import org.springframework.lang.Nullable;
import org.springframework.util.ReflectionUtils;
/**
* Proxy for a target R2DBC {@link ConnectionFactory}, adding awareness
* of Spring-managed transactions.
*
* <p>Data access code that should remain unaware of Spring's data access
* support can work with this proxy to seamlessly participate in
* Spring-managed transactions.
* Note that the transaction manager, for example {@link R2dbcTransactionManager},
* still needs to work with the underlying {@link ConnectionFactory},
* <i>not</i> with this proxy.
*
* <p><b>Make sure that {@link TransactionAwareConnectionFactoryProxy} is the outermost
* {@link ConnectionFactory} of a chain of {@link ConnectionFactory} proxies/adapters.</b>
* {@link TransactionAwareConnectionFactoryProxy} can delegate either directly to the
* target connection pool or to some intermediary proxy/adapter.
*
* <p>Delegates to {@link ConnectionFactoryUtils} for automatically participating
* in thread-bound transactions, for example managed by {@link R2dbcTransactionManager}.
* {@link #create()} calls and {@code close} calls on returned {@link Connection}
* will behave properly within a transaction, i.e. always operate on the
* transactional Connection. If not within a transaction, normal {@link ConnectionFactory}
* behavior applies.
*
* <p> This proxy allows data access code to work with the plain R2DBC API. However,
* if possible, use Spring's {@link ConnectionFactoryUtils} or {@code DatabaseClient}
* to get transaction participation even without a proxy for the target
* {@link ConnectionFactory}, avoiding the need to define such a proxy in the first place.
*
* <p><b>NOTE:</b> This {@link ConnectionFactory} proxy needs to return wrapped
* {@link Connection}s (which implement the {@link ConnectionProxy} interface) in order
* to handle close calls properly. Use {@link Wrapped#unwrap()} to retrieve
* the native R2DBC Connection.
*
* @author Mark Paluch
* @author Christoph Strobl
* @since 5.3
* @see ConnectionFactory#create
* @see Connection#close
* @see ConnectionFactoryUtils#doGetConnection
* @see ConnectionFactoryUtils#doReleaseConnection
*/
public class TransactionAwareConnectionFactoryProxy extends DelegatingConnectionFactory {
/**
 * Create a new {@link TransactionAwareConnectionFactoryProxy}.
 * @param targetConnectionFactory the target {@link ConnectionFactory}.
 * @throws IllegalArgumentException if given {@link ConnectionFactory} is {@code null}.
 */
public TransactionAwareConnectionFactoryProxy(ConnectionFactory targetConnectionFactory) {
	super(targetConnectionFactory);
}
/**
 * Delegates to {@link ConnectionFactoryUtils} for automatically participating in
 * Spring-managed transactions.
 * <p>The returned {@link Connection} handle implements the {@link ConnectionProxy}
 * interface, allowing to retrieve the underlying target {@link Connection}.
 * @return a transactional {@link Connection} if any, a new one else.
 * @see ConnectionFactoryUtils#doGetConnection
 * @see ConnectionProxy#getTargetConnection
 */
@Override
public Mono<Connection> create() {
	return getTransactionAwareConnectionProxy(obtainTargetConnectionFactory());
}
/**
 * Wraps the given {@link Connection} with a proxy that delegates every method call to it
 * but delegates {@code close()} calls to {@link ConnectionFactoryUtils}.
 * @param targetConnectionFactory the {@link ConnectionFactory} that the {@link Connection}
 * came from.
 * @return the wrapped {@link Connection}.
 * @see Connection#close()
 * @see ConnectionFactoryUtils#doReleaseConnection
 */
protected Mono<Connection> getTransactionAwareConnectionProxy(ConnectionFactory targetConnectionFactory) {
	return ConnectionFactoryUtils.getConnection(targetConnectionFactory)
			.map(connection -> proxyConnection(connection, targetConnectionFactory));
}
private static Connection proxyConnection(Connection connection, ConnectionFactory targetConnectionFactory) {
return (Connection) Proxy.newProxyInstance(TransactionAwareConnectionFactoryProxy.class.getClassLoader(),
new Class<?>[] { Connection.class, Wrapped.class },
new TransactionAwareInvocationHandler(connection, targetConnectionFactory));
}
/**
* Invocation handler that delegates close calls on R2DBC Connections to {@link ConnectionFactoryUtils} for being
* aware of context-bound transactions.
*/
private static class TransactionAwareInvocationHandler implements InvocationHandler {
private final Connection connection;
private final ConnectionFactory targetConnectionFactory;
private boolean closed = false;
TransactionAwareInvocationHandler(Connection connection, ConnectionFactory targetConnectionFactory) {
this.connection = connection;
this.targetConnectionFactory = targetConnectionFactory;
}
@Override
@Nullable
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
if (ReflectionUtils.isObjectMethod(method)) {
if (ReflectionUtils.isToStringMethod(method)) {
return proxyToString(proxy);
}
if (ReflectionUtils.isEqualsMethod(method)) {
return (proxy == args[0]);
}
if (ReflectionUtils.isHashCodeMethod(method)) {
return System.identityHashCode(proxy);
}
}
// Invocation on ConnectionProxy interface coming in...
switch (method.getName()) {
case "unwrap":
return this.connection;
case "close":
// Handle close method: only close if not within a transaction.
return ConnectionFactoryUtils.doReleaseConnection(this.connection, this.targetConnectionFactory)
.doOnSubscribe(n -> this.closed = true);
case "isClosed":
return this.closed;
}
if (this.closed) {
throw new IllegalStateException("Connection handle already closed");
}
// Invoke method on target Connection.
try {
return method.invoke(this.connection, args);
}
catch (InvocationTargetException ex) {
throw ex.getTargetException();
}
}
private String proxyToString(@Nullable Object proxy) {
// Allow for differentiating between the proxy and the raw Connection.
StringBuilder sb = new StringBuilder("Transaction-aware proxy for target Connection ");
if (this.connection != null) {
sb.append("[").append(this.connection.toString()).append("]");
}
else {
sb.append(" from ConnectionFactory [").append(this.targetConnectionFactory).append("]");
}
return sb.toString();
}
}
}

View File

@ -0,0 +1,39 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection.init;
import org.springframework.core.io.support.EncodedResource;
/**
* Thrown by {@link ScriptUtils} if an SQL script cannot be read.
*
* @author Mark Paluch
* @since 5.3
*/
@SuppressWarnings("serial")
public class CannotReadScriptException extends ScriptException {

	/**
	 * Create a new {@code CannotReadScriptException}.
	 * @param resource the resource that cannot be read from
	 * @param cause the underlying cause of the resource access failure
	 */
	public CannotReadScriptException(EncodedResource resource, Throwable cause) {
		super("Cannot read SQL script from " + resource, cause);
	}

}

View File

@ -0,0 +1,92 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection.init;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import io.r2dbc.spi.Connection;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.util.Assert;
/**
* Composite {@link DatabasePopulator} that delegates to a list of given
* {@link DatabasePopulator} implementations, executing all scripts.
*
* @author Mark Paluch
* @since 5.3
*/
public class CompositeDatabasePopulator implements DatabasePopulator {

	private final List<DatabasePopulator> populators = new ArrayList<>(4);


	/**
	 * Create an empty {@code CompositeDatabasePopulator}; delegates can be
	 * registered later.
	 * @see #setPopulators
	 * @see #addPopulators
	 */
	public CompositeDatabasePopulator() {}

	/**
	 * Create a {@code CompositeDatabasePopulator} with the given populators.
	 * @param populators one or more populators to delegate to
	 */
	public CompositeDatabasePopulator(Collection<DatabasePopulator> populators) {
		Assert.notNull(populators, "Collection of DatabasePopulator must not be null");
		this.populators.addAll(populators);
	}

	/**
	 * Create a {@code CompositeDatabasePopulator} with the given populators.
	 * @param populators one or more populators to delegate to
	 */
	public CompositeDatabasePopulator(DatabasePopulator... populators) {
		Assert.notNull(populators, "DatabasePopulators must not be null");
		this.populators.addAll(Arrays.asList(populators));
	}


	/**
	 * Specify one or more populators to delegate to, replacing any
	 * previously registered ones.
	 */
	public void setPopulators(DatabasePopulator... populators) {
		Assert.notNull(populators, "DatabasePopulators must not be null");
		this.populators.clear();
		this.populators.addAll(Arrays.asList(populators));
	}

	/**
	 * Append one or more populators to the list of delegates.
	 */
	public void addPopulators(DatabasePopulator... populators) {
		Assert.notNull(populators, "DatabasePopulators must not be null");
		this.populators.addAll(Arrays.asList(populators));
	}

	/**
	 * Run each registered delegate in registration order against the given
	 * connection, one after the other.
	 */
	@Override
	public Mono<Void> populate(Connection connection) throws ScriptException {
		Assert.notNull(connection, "Connection must not be null");
		return Flux.fromIterable(this.populators)
				.concatMap(delegate -> delegate.populate(connection))
				.then();
	}

}

View File

@ -0,0 +1,116 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection.init;
import io.r2dbc.spi.ConnectionFactory;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
/**
* Used to {@linkplain #setDatabasePopulator set up} a database during
* initialization and {@link #setDatabaseCleaner clean up} a database during
* destruction.
*
* @author Mark Paluch
* @since 5.3
* @see DatabasePopulator
*/
public class ConnectionFactoryInitializer implements InitializingBean, DisposableBean {

	@Nullable
	private ConnectionFactory connectionFactory;

	@Nullable
	private DatabasePopulator databasePopulator;

	@Nullable
	private DatabasePopulator databaseCleaner;

	private boolean enabled = true;


	/**
	 * Set the {@link ConnectionFactory} for the database to populate when this
	 * component is initialized and to clean up when this component is shut down.
	 * <p>This property is mandatory with no default provided.
	 * @param connectionFactory the R2DBC {@link ConnectionFactory}
	 */
	public void setConnectionFactory(ConnectionFactory connectionFactory) {
		this.connectionFactory = connectionFactory;
	}

	/**
	 * Set the {@link DatabasePopulator} to execute during the bean initialization phase.
	 * @param databasePopulator the {@link DatabasePopulator} to use during initialization
	 * @see #setDatabaseCleaner
	 */
	public void setDatabasePopulator(DatabasePopulator databasePopulator) {
		this.databasePopulator = databasePopulator;
	}

	/**
	 * Set the {@link DatabasePopulator} to execute during the bean destruction phase,
	 * cleaning up the database and leaving it in a known state for others.
	 * @param databaseCleaner the {@link DatabasePopulator} to use during destruction
	 * @see #setDatabasePopulator
	 */
	public void setDatabaseCleaner(DatabasePopulator databaseCleaner) {
		this.databaseCleaner = databaseCleaner;
	}

	/**
	 * Flag to explicitly enable or disable the {@link #setDatabasePopulator
	 * database populator} and {@link #setDatabaseCleaner database cleaner}.
	 * @param enabled {@code true} if the database populator and database cleaner
	 * should be called on startup and shutdown, respectively
	 */
	public void setEnabled(boolean enabled) {
		this.enabled = enabled;
	}


	/**
	 * Use the {@link #setDatabasePopulator database populator} to set up the database.
	 */
	@Override
	public void afterPropertiesSet() {
		execute(this.databasePopulator);
	}

	/**
	 * Use the {@link #setDatabaseCleaner database cleaner} to clean up the database.
	 */
	@Override
	public void destroy() {
		execute(this.databaseCleaner);
	}

	/**
	 * Run the given populator against the configured {@link ConnectionFactory},
	 * blocking until completion since bean lifecycle callbacks are synchronous.
	 * No-op when disabled or when no populator is configured.
	 */
	private void execute(@Nullable DatabasePopulator populator) {
		Assert.state(this.connectionFactory != null, "ConnectionFactory must be set");
		if (!this.enabled || populator == null) {
			return;
		}
		populator.populate(this.connectionFactory).block();
	}

}

View File

@ -0,0 +1,68 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection.init;
import io.r2dbc.spi.Connection;
import io.r2dbc.spi.ConnectionFactory;
import reactor.core.publisher.Mono;
import org.springframework.dao.DataAccessException;
import org.springframework.r2dbc.connection.ConnectionFactoryUtils;
import org.springframework.util.Assert;
/**
* Strategy used to populate, initialize, or clean up a database.
*
* @author Mark Paluch
* @since 5.3
* @see ResourceDatabasePopulator
* @see ConnectionFactoryInitializer
*/
@FunctionalInterface
public interface DatabasePopulator {

	/**
	 * Populate, initialize, or clean up the database using the provided
	 * R2DBC {@link Connection}.
	 * @param connection the R2DBC connection to use to populate the db;
	 * already configured and ready to use, must not be {@code null}
	 * @return {@link Mono} that initiates script execution and is
	 * notified upon completion
	 * @throws ScriptException in all other error cases
	 */
	Mono<Void> populate(Connection connection) throws ScriptException;

	/**
	 * Execute this {@link DatabasePopulator} against the given {@link ConnectionFactory},
	 * acquiring a connection, running the scripts, and releasing the connection on
	 * completion, error, or cancellation. Any failure that is not already a
	 * {@link ScriptException} is wrapped in an {@link UncategorizedScriptException}.
	 * @param connectionFactory the {@link ConnectionFactory} to execute against
	 * @return {@link Mono} that initiates {@link DatabasePopulator#populate(Connection)}
	 * and is notified upon completion
	 */
	default Mono<Void> populate(ConnectionFactory connectionFactory)
			throws DataAccessException {
		Assert.notNull(connectionFactory, "ConnectionFactory must not be null");
		Mono<Connection> connectionMono = ConnectionFactoryUtils.getConnection(connectionFactory);
		return Mono.usingWhen(connectionMono,
				this::populate,
				connection -> ConnectionFactoryUtils.releaseConnection(connection, connectionFactory),
				(connection, err) -> ConnectionFactoryUtils.releaseConnection(connection, connectionFactory),
				connection -> ConnectionFactoryUtils.releaseConnection(connection, connectionFactory))
				.onErrorMap(ex -> !(ex instanceof ScriptException),
						ex -> new UncategorizedScriptException("Failed to execute database script", ex));
	}

}

View File

@ -0,0 +1,273 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection.init;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import io.r2dbc.spi.Connection;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.core.io.Resource;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.core.io.support.EncodedResource;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
/**
* Populates, initializes, or cleans up a database using SQL
* scripts defined in external resources.
* <ul>
* <li>Call {@link #addScript} to add a single SQL script location.
* <li>Call {@link #addScripts} to add multiple SQL script locations.
* <li>Consult the setter methods in this class for further configuration options.
* <li>Call {@link #populate} to initialize or clean up the database using the configured scripts.
* </ul>
*
* @author Keith Donald
* @author Dave Syer
* @author Juergen Hoeller
* @author Chris Beams
* @author Oliver Gierke
* @author Sam Brannen
* @author Chris Baldwin
* @author Phillip Webb
* @author Mark Paluch
* @since 5.3
* @see ScriptUtils
*/
public class ResourceDatabasePopulator implements DatabasePopulator {

	List<Resource> scripts = new ArrayList<>();

	@Nullable
	private Charset sqlScriptEncoding;

	private String separator = ScriptUtils.DEFAULT_STATEMENT_SEPARATOR;

	private String[] commentPrefixes = ScriptUtils.DEFAULT_COMMENT_PREFIXES;

	private String blockCommentStartDelimiter = ScriptUtils.DEFAULT_BLOCK_COMMENT_START_DELIMITER;

	private String blockCommentEndDelimiter = ScriptUtils.DEFAULT_BLOCK_COMMENT_END_DELIMITER;

	private boolean continueOnError = false;

	private boolean ignoreFailedDrops = false;

	private DataBufferFactory dataBufferFactory = new DefaultDataBufferFactory();


	/**
	 * Create a new {@code ResourceDatabasePopulator} with default settings.
	 */
	public ResourceDatabasePopulator() {
	}

	/**
	 * Create a new {@code ResourceDatabasePopulator} with default settings
	 * for the supplied scripts.
	 * @param scripts the scripts to execute to initialize or clean up the
	 * database (never {@code null})
	 */
	public ResourceDatabasePopulator(Resource... scripts) {
		setScripts(scripts);
	}

	/**
	 * Construct a new {@code ResourceDatabasePopulator} with the supplied values.
	 * @param continueOnError flag to indicate that all failures in SQL should be
	 * logged but not cause a failure
	 * @param ignoreFailedDrops flag to indicate that a failed SQL {@code DROP}
	 * statement can be ignored
	 * @param sqlScriptEncoding the encoding for the supplied SQL scripts
	 * (may be {@code null} or <em>empty</em> to indicate platform encoding)
	 * @param scripts the scripts to execute to initialize or clean up the database
	 * (never {@code null})
	 */
	public ResourceDatabasePopulator(boolean continueOnError, boolean ignoreFailedDrops,
			@Nullable String sqlScriptEncoding, Resource... scripts) {

		this.continueOnError = continueOnError;
		this.ignoreFailedDrops = ignoreFailedDrops;
		setSqlScriptEncoding(sqlScriptEncoding);
		setScripts(scripts);
	}


	/**
	 * Add a script to execute to initialize or clean up the database.
	 * @param script the path to an SQL script (never {@code null})
	 */
	public void addScript(Resource script) {
		Assert.notNull(script, "'script' must not be null");
		this.scripts.add(script);
	}

	/**
	 * Add multiple scripts to execute to initialize or clean up the database.
	 * @param scripts the scripts to execute (never {@code null})
	 */
	public void addScripts(Resource... scripts) {
		assertContentsOfScriptArray(scripts);
		this.scripts.addAll(Arrays.asList(scripts));
	}

	/**
	 * Set the scripts to execute to initialize or clean up the database,
	 * replacing any previously added scripts.
	 * @param scripts the scripts to execute (never {@code null})
	 */
	public void setScripts(Resource... scripts) {
		assertContentsOfScriptArray(scripts);
		// Copy into a fresh ArrayList so that the list stays modifiable.
		this.scripts = new ArrayList<>(Arrays.asList(scripts));
	}

	// Validates that a scripts array is neither null nor contains null entries.
	private void assertContentsOfScriptArray(Resource... scripts) {
		Assert.notNull(scripts, "'scripts' must not be null");
		Assert.noNullElements(scripts, "'scripts' must not contain null elements");
	}

	/**
	 * Specify the encoding for the configured SQL scripts by name,
	 * if different from the platform encoding.
	 * @param sqlScriptEncoding the encoding used in scripts
	 * (may be {@code null} or empty to indicate platform encoding)
	 * @see #addScript(Resource)
	 */
	public void setSqlScriptEncoding(@Nullable String sqlScriptEncoding) {
		setSqlScriptEncoding(StringUtils.hasText(sqlScriptEncoding) ? Charset.forName(sqlScriptEncoding) : null);
	}

	/**
	 * Specify the encoding for the configured SQL scripts,
	 * if different from the platform encoding.
	 * @param sqlScriptEncoding the encoding used in scripts
	 * (may be {@code null} to indicate platform encoding)
	 * @see #addScript(Resource)
	 */
	public void setSqlScriptEncoding(@Nullable Charset sqlScriptEncoding) {
		this.sqlScriptEncoding = sqlScriptEncoding;
	}

	/**
	 * Specify the statement separator, if a custom one.
	 * <p>Defaults to {@code ";"} if not specified and falls back to {@code "\n"}
	 * as a last resort; may be set to {@link ScriptUtils#EOF_STATEMENT_SEPARATOR}
	 * to signal that each script contains a single statement without a separator.
	 * @param separator the script statement separator
	 */
	public void setSeparator(String separator) {
		this.separator = separator;
	}

	/**
	 * Set the prefix that identifies single-line comments within the SQL scripts.
	 * <p>Defaults to {@code "--"}.
	 * @param commentPrefix the prefix for single-line comments
	 * @see #setCommentPrefixes(String...)
	 */
	public void setCommentPrefix(String commentPrefix) {
		Assert.hasText(commentPrefix, "'commentPrefix' must not be null or empty");
		this.commentPrefixes = new String[] { commentPrefix };
	}

	/**
	 * Set the prefixes that identify single-line comments within the SQL scripts.
	 * <p>Defaults to {@code ["--"]}.
	 * @param commentPrefixes the prefixes for single-line comments
	 */
	public void setCommentPrefixes(String... commentPrefixes) {
		Assert.notEmpty(commentPrefixes, "'commentPrefixes' must not be null or empty");
		Assert.noNullElements(commentPrefixes, "'commentPrefixes' must not contain null elements");
		this.commentPrefixes = commentPrefixes;
	}

	/**
	 * Set the start delimiter that identifies block comments within the SQL scripts.
	 * <p>Defaults to {@code "/*"}.
	 * @param blockCommentStartDelimiter the start delimiter for block comments
	 * (never {@code null} or empty)
	 * @see #setBlockCommentEndDelimiter
	 */
	public void setBlockCommentStartDelimiter(String blockCommentStartDelimiter) {
		Assert.hasText(blockCommentStartDelimiter, "'blockCommentStartDelimiter' must not be null or empty");
		this.blockCommentStartDelimiter = blockCommentStartDelimiter;
	}

	/**
	 * Set the end delimiter that identifies block comments within the SQL scripts.
	 * <p>Defaults to <code>"*&#47;"</code>.
	 * @param blockCommentEndDelimiter the end delimiter for block comments
	 * (never {@code null} or empty)
	 * @see #setBlockCommentStartDelimiter
	 */
	public void setBlockCommentEndDelimiter(String blockCommentEndDelimiter) {
		Assert.hasText(blockCommentEndDelimiter, "'blockCommentEndDelimiter' must not be null or empty");
		this.blockCommentEndDelimiter = blockCommentEndDelimiter;
	}

	/**
	 * Flag to indicate that all failures in SQL should be logged but not cause a failure.
	 * <p>Defaults to {@code false}.
	 * @param continueOnError {@code true} if script execution should continue on error
	 */
	public void setContinueOnError(boolean continueOnError) {
		this.continueOnError = continueOnError;
	}

	/**
	 * Flag to indicate that a failed SQL {@code DROP} statement can be ignored.
	 * <p>This is useful for a non-embedded database whose SQL dialect does not
	 * support an {@code IF EXISTS} clause in a {@code DROP} statement.
	 * <p>The default is {@code false} so that if the populator runs accidentally,
	 * it will fail fast if a script starts with a {@code DROP} statement.
	 * @param ignoreFailedDrops {@code true} if failed drop statements should be ignored
	 */
	public void setIgnoreFailedDrops(boolean ignoreFailedDrops) {
		this.ignoreFailedDrops = ignoreFailedDrops;
	}

	/**
	 * Set the {@link DataBufferFactory} to use for {@link Resource} loading.
	 * <p>Defaults to {@link DefaultDataBufferFactory}.
	 * @param dataBufferFactory the {@link DataBufferFactory} to use, must not be {@code null}
	 */
	public void setDataBufferFactory(DataBufferFactory dataBufferFactory) {
		Assert.notNull(dataBufferFactory, "DataBufferFactory must not be null");
		this.dataBufferFactory = dataBufferFactory;
	}


	/**
	 * Execute each configured script in order against the given connection,
	 * applying the configured encoding, separator, comment delimiters, and
	 * error-handling flags.
	 */
	@Override
	public Mono<Void> populate(Connection connection) throws ScriptException {
		Assert.notNull(connection, "Connection must not be null");
		return Flux.fromIterable(this.scripts)
				.concatMap(script -> runScript(script, connection))
				.then();
	}

	// Executes a single script resource with this populator's current settings.
	private Mono<Void> runScript(Resource script, Connection connection) {
		EncodedResource encodedScript = new EncodedResource(script, this.sqlScriptEncoding);
		return ScriptUtils.executeSqlScript(connection, encodedScript, this.dataBufferFactory,
				this.continueOnError, this.ignoreFailedDrops, this.commentPrefixes, this.separator,
				this.blockCommentStartDelimiter, this.blockCommentEndDelimiter);
	}

}

View File

@ -0,0 +1,48 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection.init;
import org.springframework.dao.DataAccessException;
import org.springframework.lang.Nullable;
/**
* Root of the hierarchy of data access exceptions that are related to processing of SQL scripts.
*
* @author Mark Paluch
* @since 5.3
*/
@SuppressWarnings("serial")
public abstract class ScriptException extends DataAccessException {

	/**
	 * Create a new {@code ScriptException}.
	 * @param message the detail message
	 */
	public ScriptException(String message) {
		super(message);
	}

	/**
	 * Create a new {@code ScriptException}.
	 * @param message the detail message
	 * @param cause the root cause
	 */
	public ScriptException(String message, @Nullable Throwable cause) {
		super(message, cause);
	}

}

View File

@ -0,0 +1,56 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection.init;
import org.springframework.core.io.support.EncodedResource;
import org.springframework.lang.Nullable;
/**
* Thrown by {@link ScriptUtils} if an SQL script cannot be properly parsed.
*
* @author Mark Paluch
* @since 5.3
*/
@SuppressWarnings("serial")
public class ScriptParseException extends ScriptException {

	/**
	 * Create a new {@code ScriptParseException}.
	 * @param message the detail message
	 * @param resource the resource from which the SQL script was read
	 */
	public ScriptParseException(String message, @Nullable EncodedResource resource) {
		super(buildMessage(message, resource));
	}

	/**
	 * Create a new {@code ScriptParseException}.
	 * @param message the detail message
	 * @param resource the resource from which the SQL script was read
	 * @param cause the underlying cause of the failure
	 */
	public ScriptParseException(String message, @Nullable EncodedResource resource, @Nullable Throwable cause) {
		super(buildMessage(message, resource), cause);
	}

	// Builds the detail message, describing an absent resource as "<unknown>".
	private static String buildMessage(String message, @Nullable EncodedResource resource) {
		String resourceDescription = (resource != null ? resource.toString() : "<unknown>");
		return "Failed to parse SQL script from resource [" + resourceDescription + "]: " + message;
	}

}

View File

@ -0,0 +1,58 @@
/*
 * Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection.init;
import org.springframework.core.io.support.EncodedResource;
/**
* Thrown by {@link ScriptUtils} if a statement in an SQL script failed when
* executing it against the target database.
*
* @author Mark Paluch
* @since 5.3
*/
@SuppressWarnings("serial")
public class ScriptStatementFailedException extends ScriptException {

	/**
	 * Create a new {@code ScriptStatementFailedException}.
	 * @param stmt the actual SQL statement that failed
	 * @param stmtNumber the statement number in the SQL script (i.e.,
	 * the n<sup>th</sup> statement present in the resource)
	 * @param encodedResource the resource from which the SQL statement was read
	 * @param cause the underlying cause of the failure
	 */
	public ScriptStatementFailedException(String stmt, int stmtNumber, EncodedResource encodedResource, Throwable cause) {
		super(buildErrorMessage(stmt, stmtNumber, encodedResource), cause);
	}

	/**
	 * Build an error message for an SQL script execution failure,
	 * based on the supplied arguments.
	 * @param stmt the actual SQL statement that failed
	 * @param stmtNumber the statement number in the SQL script (i.e.,
	 * the n<sup>th</sup> statement present in the resource)
	 * @param encodedResource the resource from which the SQL statement was read
	 * @return an error message suitable for an exception's <em>detail message</em>
	 * or logging
	 */
	public static String buildErrorMessage(String stmt, int stmtNumber, EncodedResource encodedResource) {
		return "Failed to execute SQL script statement #" + stmtNumber +
				" of " + encodedResource + ": " + stmt;
	}

}

View File

@ -0,0 +1,666 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection.init;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.LineNumberReader;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import io.r2dbc.spi.Connection;
import io.r2dbc.spi.Result;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.core.io.Resource;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.core.io.support.EncodedResource;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
/**
* Generic utility methods for working with SQL scripts.
* <p>Mainly for internal use within the framework.
*
* @author Thomas Risberg
* @author Sam Brannen
* @author Juergen Hoeller
* @author Keith Donald
* @author Dave Syer
* @author Chris Beams
* @author Oliver Gierke
* @author Chris Baldwin
* @author Nicolas Debeissat
* @author Phillip Webb
* @author Mark Paluch
* @since 5.3
*/
public abstract class ScriptUtils {

	/**
	 * Default statement separator within SQL scripts: {@code ";"}.
	 */
	public static final String DEFAULT_STATEMENT_SEPARATOR = ";";

	/**
	 * Fallback statement separator within SQL scripts: {@code "\n"}.
	 * <p>Used if neither a custom separator nor the
	 * {@link #DEFAULT_STATEMENT_SEPARATOR} is present in a given script.
	 */
	public static final String FALLBACK_STATEMENT_SEPARATOR = "\n";

	/**
	 * End of file (EOF) SQL statement separator: {@code "^^^ END OF SCRIPT ^^^"}.
	 * <p>This value may be supplied as the {@code separator} to {@link
	 * #executeSqlScript(Connection, EncodedResource, DataBufferFactory, boolean, boolean, String[], String, String, String)}
	 * to denote that an SQL script contains a single statement (potentially
	 * spanning multiple lines) with no explicit statement separator. Note that
	 * such a script should not actually contain this value; it is merely a
	 * <em>virtual</em> statement separator.
	 */
	public static final String EOF_STATEMENT_SEPARATOR = "^^^ END OF SCRIPT ^^^";

	/**
	 * Default prefix for single-line comments within SQL scripts: {@code "--"}.
	 */
	public static final String DEFAULT_COMMENT_PREFIX = "--";

	/**
	 * Default prefixes for single-line comments within SQL scripts: {@code ["--"]}.
	 */
	public static final String[] DEFAULT_COMMENT_PREFIXES = {DEFAULT_COMMENT_PREFIX};

	/**
	 * Default start delimiter for block comments within SQL scripts: {@code "/*"}.
	 */
	public static final String DEFAULT_BLOCK_COMMENT_START_DELIMITER = "/*";

	/**
	 * Default end delimiter for block comments within SQL scripts: <code>"*&#47;"</code>.
	 */
	public static final String DEFAULT_BLOCK_COMMENT_END_DELIMITER = "*/";

	private static final Log logger = LogFactory.getLog(ScriptUtils.class);

	// Static utility holder: the private constructor prevents instantiation,
	// and the class only exposes static methods.
	private ScriptUtils() {}

	/**
	 * Split an SQL script into separate statements delimited by the provided
	 * separator character. Each individual statement will be added to the
	 * provided {@code List}.
	 * <p>Within the script, {@value #DEFAULT_COMMENT_PREFIX} will be used as the
	 * comment prefix; any text beginning with the comment prefix and extending to
	 * the end of the line will be omitted from the output. Similarly,
	 * {@value #DEFAULT_BLOCK_COMMENT_START_DELIMITER} and
	 * {@value #DEFAULT_BLOCK_COMMENT_END_DELIMITER} will be used as the
	 * <em>start</em> and <em>end</em> block comment delimiters: any text enclosed
	 * in a block comment will be omitted from the output. In addition, multiple
	 * adjacent whitespace characters will be collapsed into a single space.
	 * @param script the SQL script
	 * @param separator character separating each statement (typically a ';')
	 * @param statements the list that will contain the individual statements
	 * @throws ScriptException if an error occurred while splitting the SQL script
	 * @see #splitSqlScript(String, String, List)
	 * @see #splitSqlScript(EncodedResource, String, String, String, String, String, List)
	 */
	public static void splitSqlScript(String script, char separator, List<String> statements) throws ScriptException {
		// Convenience overload: delegate to the String-separator variant.
		splitSqlScript(script, String.valueOf(separator), statements);
	}

	/**
	 * Split an SQL script into separate statements delimited by the provided
	 * separator string. Each individual statement will be added to the
	 * provided {@code List}.
	 * <p>Within the script, {@value #DEFAULT_COMMENT_PREFIX} will be used as the
	 * comment prefix; any text beginning with the comment prefix and extending to
	 * the end of the line will be omitted from the output. Similarly,
	 * {@value #DEFAULT_BLOCK_COMMENT_START_DELIMITER} and
	 * {@value #DEFAULT_BLOCK_COMMENT_END_DELIMITER} will be used as the
	 * <em>start</em> and <em>end</em> block comment delimiters: any text enclosed
	 * in a block comment will be omitted from the output. In addition, multiple
	 * adjacent whitespace characters will be collapsed into a single space.
	 * @param script the SQL script
	 * @param separator text separating each statement
	 * (typically a ';' or newline character)
	 * @param statements the list that will contain the individual statements
	 * @throws ScriptException if an error occurred while splitting the SQL script
	 * @see #splitSqlScript(String, char, List)
	 * @see #splitSqlScript(EncodedResource, String, String, String, String, String, List)
	 */
	public static void splitSqlScript(String script, String separator, List<String> statements) throws ScriptException {
		// Delegate using the default comment prefix and block comment delimiters;
		// no resource is available in this code path (null is allowed for error reporting).
		splitSqlScript(null, script, separator, DEFAULT_COMMENT_PREFIX, DEFAULT_BLOCK_COMMENT_START_DELIMITER,
				DEFAULT_BLOCK_COMMENT_END_DELIMITER, statements);
	}

	/**
	 * Split an SQL script into separate statements delimited by the provided
	 * separator string. Each individual statement will be added to the provided
	 * {@code List}.
	 * <p>Within the script, the provided {@code commentPrefix} will be honored:
	 * any text beginning with the comment prefix and extending to the end of the
	 * line will be omitted from the output. Similarly, the provided
	 * {@code blockCommentStartDelimiter} and {@code blockCommentEndDelimiter}
	 * delimiters will be honored: any text enclosed in a block comment will be
	 * omitted from the output. In addition, multiple adjacent whitespace characters
	 * will be collapsed into a single space.
	 * @param resource the resource from which the script was read
	 * @param script the SQL script
	 * @param separator text separating each statement
	 * (typically a ';' or newline character)
	 * @param commentPrefix the prefix that identifies SQL line comments
	 * (typically "--")
	 * @param blockCommentStartDelimiter the <em>start</em> block comment delimiter;
	 * never {@code null} or empty
	 * @param blockCommentEndDelimiter the <em>end</em> block comment delimiter;
	 * never {@code null} or empty
	 * @param statements the list that will contain the individual statements
	 * @throws ScriptException if an error occurred while splitting the SQL script
	 */
	public static void splitSqlScript(@Nullable EncodedResource resource, String script,
			String separator, String commentPrefix, String blockCommentStartDelimiter,
			String blockCommentEndDelimiter, List<String> statements) throws ScriptException {

		Assert.hasText(commentPrefix, "'commentPrefix' must not be null or empty");
		// Wrap the single prefix into an array and delegate to the multi-prefix variant.
		splitSqlScript(resource, script, separator, new String[] { commentPrefix },
				blockCommentStartDelimiter, blockCommentEndDelimiter, statements);
	}

	/**
	 * Split an SQL script into separate statements delimited by the provided
	 * separator string. Each individual statement will be added to the provided
	 * {@code List}.
	 * <p>Within the script, the provided {@code commentPrefixes} will be honored:
	 * any text beginning with one of the comment prefixes and extending to the
	 * end of the line will be omitted from the output. Similarly, the provided
	 * {@code blockCommentStartDelimiter} and {@code blockCommentEndDelimiter}
	 * delimiters will be honored: any text enclosed in a block comment will be
	 * omitted from the output. In addition, multiple adjacent whitespace characters
	 * will be collapsed into a single space.
	 * @param resource the resource from which the script was read
	 * @param script the SQL script
	 * @param separator text separating each statement
	 * (typically a ';' or newline character)
	 * @param commentPrefixes the prefixes that identify SQL line comments
	 * (typically "--")
	 * @param blockCommentStartDelimiter the <em>start</em> block comment delimiter;
	 * never {@code null} or empty
	 * @param blockCommentEndDelimiter the <em>end</em> block comment delimiter;
	 * never {@code null} or empty
	 * @param statements the list that will contain the individual statements
	 * @throws ScriptException if an error occurred while splitting the SQL script
	 */
	public static void splitSqlScript(@Nullable EncodedResource resource, String script,
			String separator, String[] commentPrefixes, String blockCommentStartDelimiter,
			String blockCommentEndDelimiter, List<String> statements) throws ScriptException {

		Assert.hasText(script, "'script' must not be null or empty");
		Assert.notNull(separator, "'separator' must not be null");
		Assert.notEmpty(commentPrefixes, "'commentPrefixes' must not be null or empty");
		for (String commentPrefix : commentPrefixes) {
			Assert.hasText(commentPrefix, "'commentPrefixes' must not contain null or empty elements");
		}
		Assert.hasText(blockCommentStartDelimiter, "'blockCommentStartDelimiter' must not be null or empty");
		Assert.hasText(blockCommentEndDelimiter, "'blockCommentEndDelimiter' must not be null or empty");

		StringBuilder sb = new StringBuilder();
		// Parser state: separators, comments, and whitespace collapsing are only
		// honored OUTSIDE of quoted literals, and backslash-escaped characters
		// are always copied through verbatim.
		boolean inSingleQuote = false;
		boolean inDoubleQuote = false;
		boolean inEscape = false;

		for (int i = 0; i < script.length(); i++) {
			char c = script.charAt(i);
			if (inEscape) {
				// Previous character was a backslash: take this one literally.
				inEscape = false;
				sb.append(c);
				continue;
			}
			// MySQL style escapes
			if (c == '\\') {
				inEscape = true;
				sb.append(c);
				continue;
			}
			// Toggle quote state; a quote character of one kind inside the other
			// kind of quoted region is treated as plain content.
			if (!inDoubleQuote && (c == '\'')) {
				inSingleQuote = !inSingleQuote;
			}
			else if (!inSingleQuote && (c == '"')) {
				inDoubleQuote = !inDoubleQuote;
			}
			if (!inSingleQuote && !inDoubleQuote) {
				if (script.startsWith(separator, i)) {
					// We've reached the end of the current statement
					if (sb.length() > 0) {
						statements.add(sb.toString());
						sb = new StringBuilder();
					}
					// Skip the rest of the separator; the loop increment consumes its first char.
					i += separator.length() - 1;
					continue;
				}
				else if (startsWithAny(script, commentPrefixes, i)) {
					// Skip over any content from the start of the comment to the EOL
					int indexOfNextNewline = script.indexOf('\n', i);
					if (indexOfNextNewline > i) {
						i = indexOfNextNewline;
						continue;
					}
					else {
						// If there's no EOL, we must be at the end of the script, so stop here.
						break;
					}
				}
				else if (script.startsWith(blockCommentStartDelimiter, i)) {
					// Skip over any block comments
					int indexOfCommentEnd = script.indexOf(blockCommentEndDelimiter, i);
					if (indexOfCommentEnd > i) {
						i = indexOfCommentEnd + blockCommentEndDelimiter.length() - 1;
						continue;
					}
					else {
						// An unterminated block comment is a hard parse error.
						throw new ScriptParseException(
								"Missing block comment end delimiter: " + blockCommentEndDelimiter, resource);
					}
				}
				else if (c == ' ' || c == '\r' || c == '\n' || c == '\t') {
					// Avoid multiple adjacent whitespace characters
					if (sb.length() > 0 && sb.charAt(sb.length() - 1) != ' ') {
						c = ' ';
					}
					else {
						continue;
					}
				}
			}
			sb.append(c);
		}

		// Flush a trailing statement that was not terminated by the separator.
		if (StringUtils.hasText(sb)) {
			statements.add(sb.toString());
		}
	}

	/**
	 * Read a script from the given resource, using "{@code --}" as the comment prefix
	 * and "{@code ;}" as the statement separator, and build a String containing the lines.
	 * @param resource the {@code EncodedResource} to be read
	 * @param dataBufferFactory the factory to use for buffering the resource content
	 * while reading it
	 * @return {@code String} containing the script lines
	 */
	public static Mono<String> readScript(EncodedResource resource, DataBufferFactory dataBufferFactory) {
		return readScript(resource, dataBufferFactory, DEFAULT_COMMENT_PREFIXES, DEFAULT_STATEMENT_SEPARATOR,
				DEFAULT_BLOCK_COMMENT_END_DELIMITER);
	}

	/**
	 * Read a script from the provided resource, using the supplied comment prefixes
	 * and statement separator, and build a {@code String} containing the lines.
	 * <p>Lines <em>beginning</em> with one of the comment prefixes are excluded
	 * from the results; however, line comments anywhere else &mdash; for example,
	 * within a statement &mdash; will be included in the results.
	 * @param resource the {@code EncodedResource} containing the script
	 * to be processed
	 * @param dataBufferFactory the factory to use for buffering the resource content
	 * while reading it
	 * @param commentPrefixes the prefixes that identify comments in the SQL script
	 * (typically "--")
	 * @param separator the statement separator in the SQL script (typically ";")
	 * @param blockCommentEndDelimiter the <em>end</em> block comment delimiter
	 * @return a {@link Mono} of {@link String} containing the script lines that
	 * completes once the resource was loaded
	 */
	private static Mono<String> readScript(EncodedResource resource, DataBufferFactory dataBufferFactory,
			@Nullable String[] commentPrefixes, @Nullable String separator, @Nullable String blockCommentEndDelimiter) {

		// Join all buffers of the resource into a single DataBuffer, then parse it
		// synchronously inside handle(); the buffer is always released afterwards.
		return DataBufferUtils.join(DataBufferUtils.read(resource.getResource(), dataBufferFactory, 8192))
				.handle((it, sink) -> {
					try (InputStream is = it.asInputStream()) {
						// Honor an explicit charset on the resource; otherwise use the platform default.
						InputStreamReader in = resource.getCharset() != null ? new InputStreamReader(is, resource.getCharset())
								: new InputStreamReader(is);
						LineNumberReader lnr = new LineNumberReader(in);
						String script = readScript(lnr, commentPrefixes, separator, blockCommentEndDelimiter);
						sink.next(script);
						sink.complete();
					}
					catch (Exception ex) {
						sink.error(ex);
					}
					finally {
						DataBufferUtils.release(it);
					}
				});
	}

	/**
	 * Read a script from the provided {@code LineNumberReader}, using the supplied
	 * comment prefix and statement separator, and build a {@code String} containing
	 * the lines.
	 * <p>Lines <em>beginning</em> with the comment prefix are excluded from the
	 * results; however, line comments anywhere else &mdash; for example, within
	 * a statement &mdash; will be included in the results.
	 * @param lineNumberReader the {@code LineNumberReader} containing the script
	 * to be processed
	 * @param lineCommentPrefix the prefix that identifies comments in the SQL script
	 * (typically "--")
	 * @param separator the statement separator in the SQL script (typically ";")
	 * @param blockCommentEndDelimiter the <em>end</em> block comment delimiter
	 * @return a {@code String} containing the script lines
	 * @throws IOException in case of I/O errors
	 */
	public static String readScript(LineNumberReader lineNumberReader, @Nullable String lineCommentPrefix,
			@Nullable String separator, @Nullable String blockCommentEndDelimiter) throws IOException {

		// Convenience overload: wrap the single prefix (if any) and delegate.
		String[] lineCommentPrefixes = (lineCommentPrefix != null) ? new String[] { lineCommentPrefix } : null;
		return readScript(lineNumberReader, lineCommentPrefixes, separator, blockCommentEndDelimiter);
	}

	/**
	 * Read a script from the provided {@code LineNumberReader}, using the supplied
	 * comment prefixes and statement separator, and build a {@code String} containing
	 * the lines.
	 * <p>Lines <em>beginning</em> with one of the comment prefixes are excluded
	 * from the results; however, line comments anywhere else &mdash; for example,
	 * within a statement &mdash; will be included in the results.
	 * @param lineNumberReader the {@code LineNumberReader} containing the script
	 * to be processed
	 * @param lineCommentPrefixes the prefixes that identify comments in the SQL script
	 * (typically "--")
	 * @param separator the statement separator in the SQL script (typically ";")
	 * @param blockCommentEndDelimiter the <em>end</em> block comment delimiter
	 * @return a {@code String} containing the script lines
	 * @throws IOException in case of I/O errors
	 */
	public static String readScript(LineNumberReader lineNumberReader, @Nullable String[] lineCommentPrefixes,
			@Nullable String separator, @Nullable String blockCommentEndDelimiter) throws IOException {

		String currentStatement = lineNumberReader.readLine();
		StringBuilder scriptBuilder = new StringBuilder();
		while (currentStatement != null) {
			// Keep a line if it contains the block-comment end delimiter (the tail of
			// a block comment may share a line with real SQL) or if it does not start
			// with one of the single-line comment prefixes.
			// NOTE(review): when lineCommentPrefixes is null, only lines containing the
			// block comment end delimiter are kept; the callers in this class always
			// pass non-null prefixes — verify before calling with null.
			if ((blockCommentEndDelimiter != null && currentStatement.contains(blockCommentEndDelimiter)) ||
					(lineCommentPrefixes != null && !startsWithAny(currentStatement, lineCommentPrefixes, 0))) {
				if (scriptBuilder.length() > 0) {
					scriptBuilder.append('\n');
				}
				scriptBuilder.append(currentStatement);
			}
			currentStatement = lineNumberReader.readLine();
		}
		appendSeparatorToScriptIfNecessary(scriptBuilder, separator);
		return scriptBuilder.toString();
	}

	/**
	 * Re-append the whitespace tail of a separator (e.g. the newline of ";\n") if
	 * the script already ends with the trimmed separator, so that the last
	 * statement is recognized when splitting.
	 */
	private static void appendSeparatorToScriptIfNecessary(StringBuilder scriptBuilder, @Nullable String separator) {
		if (separator == null) {
			return;
		}
		String trimmed = separator.trim();
		if (trimmed.length() == separator.length()) {
			// Separator contains no surrounding whitespace: nothing to re-append.
			return;
		}
		// separator ends in whitespace, so we might want to see if the script is trying
		// to end the same way
		if (scriptBuilder.lastIndexOf(trimmed) == scriptBuilder.length() - trimmed.length()) {
			scriptBuilder.append(separator.substring(trimmed.length()));
		}
	}

	/**
	 * Return whether {@code script} starts with any of the given prefixes at the
	 * given offset.
	 */
	private static boolean startsWithAny(String script, String[] prefixes, int offset) {
		for (String prefix : prefixes) {
			if (script.startsWith(prefix, offset)) {
				return true;
			}
		}
		return false;
	}

	/**
	 * Does the provided SQL script contain the specified delimiter?
	 * <p>Delimiters inside single-quoted literals and backslash-escaped characters
	 * are ignored.
	 * @param script the SQL script
	 * @param delim the string delimiting each statement - typically a ';' character
	 */
	// NOTE(review): unlike splitSqlScript, this check only tracks single-quoted
	// literals (not double-quoted identifiers) — confirm whether that asymmetry
	// is intentional for scripts containing double-quoted identifiers.
	public static boolean containsSqlScriptDelimiters(String script, String delim) {
		boolean inLiteral = false;
		boolean inEscape = false;
		for (int i = 0; i < script.length(); i++) {
			char c = script.charAt(i);
			if (inEscape) {
				inEscape = false;
				continue;
			}
			// MySQL style escapes
			if (c == '\\') {
				inEscape = true;
				continue;
			}
			if (c == '\'') {
				inLiteral = !inLiteral;
			}
			if (!inLiteral && script.startsWith(delim, i)) {
				return true;
			}
		}
		return false;
	}

	/**
	 * Execute the given SQL script using default settings for statement
	 * separators, comment delimiters, and exception handling flags.
	 * <p>Statement separators and comments will be removed before executing
	 * individual statements within the supplied script.
	 * <p><strong>Warning</strong>: this method does <em>not</em> release the
	 * provided {@link Connection}.
	 * @param connection the R2DBC connection to use to execute the script; already
	 * configured and ready to use
	 * @param resource the resource to load the SQL script from; encoded with the
	 * current platform's default encoding
	 * @throws ScriptException if an error occurred while executing the SQL script
	 * @see #executeSqlScript(Connection, EncodedResource, DataBufferFactory, boolean, boolean, String[], String, String, String)
	 * @see #DEFAULT_STATEMENT_SEPARATOR
	 * @see #DEFAULT_COMMENT_PREFIX
	 * @see #DEFAULT_BLOCK_COMMENT_START_DELIMITER
	 * @see #DEFAULT_BLOCK_COMMENT_END_DELIMITER
	 * @see org.springframework.r2dbc.connection.ConnectionFactoryUtils#getConnection
	 * @see org.springframework.r2dbc.connection.ConnectionFactoryUtils#releaseConnection
	 */
	public static Mono<Void> executeSqlScript(Connection connection, Resource resource) throws ScriptException {
		return executeSqlScript(connection, new EncodedResource(resource));
	}

	/**
	 * Execute the given SQL script using default settings for statement
	 * separators, comment delimiters, and exception handling flags.
	 * <p>Statement separators and comments will be removed before executing
	 * individual statements within the supplied script.
	 * <p><strong>Warning</strong>: this method does <em>not</em> release the
	 * provided {@link Connection}.
	 * @param connection the R2DBC connection to use to execute the script; already
	 * configured and ready to use
	 * @param resource the resource (potentially associated with a specific encoding)
	 * to load the SQL script from
	 * @throws ScriptException if an error occurred while executing the SQL script
	 * @see #executeSqlScript(Connection, EncodedResource, DataBufferFactory, boolean, boolean, String[], String, String, String)
	 * @see #DEFAULT_STATEMENT_SEPARATOR
	 * @see #DEFAULT_COMMENT_PREFIX
	 * @see #DEFAULT_BLOCK_COMMENT_START_DELIMITER
	 * @see #DEFAULT_BLOCK_COMMENT_END_DELIMITER
	 * @see org.springframework.r2dbc.connection.ConnectionFactoryUtils#getConnection
	 * @see org.springframework.r2dbc.connection.ConnectionFactoryUtils#releaseConnection
	 */
	public static Mono<Void> executeSqlScript(Connection connection, EncodedResource resource) throws ScriptException {
		// Fail-fast defaults: continueOnError=false, ignoreFailedDrops=false,
		// with a freshly allocated default DataBufferFactory.
		return executeSqlScript(connection, resource, new DefaultDataBufferFactory(), false, false, DEFAULT_COMMENT_PREFIX,
				DEFAULT_STATEMENT_SEPARATOR, DEFAULT_BLOCK_COMMENT_START_DELIMITER, DEFAULT_BLOCK_COMMENT_END_DELIMITER);
	}

	/**
	 * Execute the given SQL script.
	 * <p>Statement separators and comments will be removed before executing
	 * individual statements within the supplied script.
	 * <p><strong>Warning</strong>: this method does <em>not</em> release the
	 * provided {@link Connection}.
	 * @param connection the R2DBC connection to use to execute the script; already
	 * configured and ready to use
	 * @param resource the resource (potentially associated with a specific encoding)
	 * to load the SQL script from
	 * @param dataBufferFactory the factory to use for buffering the script resource
	 * while reading it
	 * @param continueOnError whether or not to continue without throwing an exception
	 * in the event of an error
	 * @param ignoreFailedDrops whether or not to continue in the event of specifically
	 * an error on a {@code DROP} statement
	 * @param commentPrefix the prefix that identifies single-line comments in the
	 * SQL script (typically "--")
	 * @param separator the script statement separator; defaults to
	 * {@value #DEFAULT_STATEMENT_SEPARATOR} if not specified and falls back to
	 * {@value #FALLBACK_STATEMENT_SEPARATOR} as a last resort; may be set to
	 * {@value #EOF_STATEMENT_SEPARATOR} to signal that the script contains a
	 * single statement without a separator
	 * @param blockCommentStartDelimiter the <em>start</em> block comment delimiter
	 * @param blockCommentEndDelimiter the <em>end</em> block comment delimiter
	 * @throws ScriptException if an error occurred while executing the SQL script
	 * @see #DEFAULT_STATEMENT_SEPARATOR
	 * @see #FALLBACK_STATEMENT_SEPARATOR
	 * @see #EOF_STATEMENT_SEPARATOR
	 * @see org.springframework.r2dbc.connection.ConnectionFactoryUtils#getConnection
	 * @see org.springframework.r2dbc.connection.ConnectionFactoryUtils#releaseConnection
	 */
	public static Mono<Void> executeSqlScript(Connection connection, EncodedResource resource,
			DataBufferFactory dataBufferFactory, boolean continueOnError, boolean ignoreFailedDrops, String commentPrefix,
			@Nullable String separator, String blockCommentStartDelimiter, String blockCommentEndDelimiter)
			throws ScriptException {

		// Wrap the single comment prefix and delegate to the multi-prefix variant.
		return executeSqlScript(connection, resource, dataBufferFactory, continueOnError,
				ignoreFailedDrops, new String[] { commentPrefix }, separator,
				blockCommentStartDelimiter, blockCommentEndDelimiter);
	}

	/**
	 * Execute the given SQL script.
	 * <p>Statement separators and comments will be removed before executing
	 * individual statements within the supplied script.
	 * <p><strong>Warning</strong>: this method does <em>not</em> release the
	 * provided {@link Connection}.
	 * @param connection the R2DBC connection to use to execute the script; already
	 * configured and ready to use
	 * @param resource the resource (potentially associated with a specific encoding)
	 * to load the SQL script from
	 * @param dataBufferFactory the factory to use for buffering the script resource
	 * while reading it
	 * @param continueOnError whether or not to continue without throwing an exception
	 * in the event of an error
	 * @param ignoreFailedDrops whether or not to continue in the event of specifically
	 * an error on a {@code DROP} statement
	 * @param commentPrefixes the prefixes that identify single-line comments in the
	 * SQL script (typically "--")
	 * @param separator the script statement separator; defaults to
	 * {@value #DEFAULT_STATEMENT_SEPARATOR} if not specified and falls back to
	 * {@value #FALLBACK_STATEMENT_SEPARATOR} as a last resort; may be set to
	 * {@value #EOF_STATEMENT_SEPARATOR} to signal that the script contains a
	 * single statement without a separator
	 * @param blockCommentStartDelimiter the <em>start</em> block comment delimiter
	 * @param blockCommentEndDelimiter the <em>end</em> block comment delimiter
	 * @throws ScriptException if an error occurred while executing the SQL script
	 * @see #DEFAULT_STATEMENT_SEPARATOR
	 * @see #FALLBACK_STATEMENT_SEPARATOR
	 * @see #EOF_STATEMENT_SEPARATOR
	 * @see org.springframework.r2dbc.connection.ConnectionFactoryUtils#getConnection
	 * @see org.springframework.r2dbc.connection.ConnectionFactoryUtils#releaseConnection
	 */
	public static Mono<Void> executeSqlScript(Connection connection, EncodedResource resource, DataBufferFactory dataBufferFactory,
			boolean continueOnError,
			boolean ignoreFailedDrops, String[] commentPrefixes, @Nullable String separator,
			String blockCommentStartDelimiter, String blockCommentEndDelimiter) throws ScriptException {

		if (logger.isDebugEnabled()) {
			logger.debug("Executing SQL script from " + resource);
		}
		// Captured eagerly (at assembly time) for the elapsed-time debug log below.
		long startTime = System.currentTimeMillis();

		Mono<String> script = readScript(resource, dataBufferFactory, commentPrefixes, separator, blockCommentEndDelimiter)
				.onErrorMap(IOException.class, ex -> new CannotReadScriptException(resource, ex));
		AtomicInteger statementNumber = new AtomicInteger();
		// Split the loaded script into statements, then execute them strictly in
		// order (concatMap preserves sequence and does not interleave statements).
		Flux<Void> executeScript = script.flatMapIterable(statement -> {
			List<String> statements = new ArrayList<>();
			String separatorToUse = separator;
			if (separatorToUse == null) {
				separatorToUse = DEFAULT_STATEMENT_SEPARATOR;
			}
			// If the chosen separator never occurs (and EOF mode was not requested),
			// fall back to treating each line as a statement.
			if (!EOF_STATEMENT_SEPARATOR.equals(separatorToUse) && !containsSqlScriptDelimiters(statement, separatorToUse)) {
				separatorToUse = FALLBACK_STATEMENT_SEPARATOR;
			}
			splitSqlScript(resource, statement, separatorToUse, commentPrefixes, blockCommentStartDelimiter,
					blockCommentEndDelimiter, statements);
			return statements;
		}).concatMap(statement -> {
			// 1-based statement counter for error reporting.
			statementNumber.incrementAndGet();
			return runStatement(statement, connection, resource, continueOnError, ignoreFailedDrops, statementNumber);
		});

		if (logger.isDebugEnabled()) {
			executeScript = executeScript.doOnComplete(() -> {
				long elapsedTime = System.currentTimeMillis() - startTime;
				logger.debug("Executed SQL script from " + resource + " in " + elapsedTime + " ms.");
			});
		}

		// Wrap any non-ScriptException failure so callers consistently observe
		// ScriptException subtypes.
		return executeScript.onErrorMap(ex -> !(ex instanceof ScriptException),
				ex -> new UncategorizedScriptException("Failed to execute database script from resource [" + resource + "]",
						ex))
				.then();
	}

	/**
	 * Execute a single SQL statement on the given connection, applying the
	 * {@code continueOnError}/{@code ignoreFailedDrops} error policy.
	 * <p>Suppressed failures are logged at debug level; all other failures are
	 * surfaced as {@link ScriptStatementFailedException}.
	 */
	private static Publisher<? extends Void> runStatement(String statement, Connection connection,
			EncodedResource resource, boolean continueOnError, boolean ignoreFailedDrops, AtomicInteger statementNumber) {

		// Sum the update counts across all results emitted by the statement.
		Mono<Long> execution = Flux.from(connection.createStatement(statement).execute())
				.flatMap(Result::getRowsUpdated)
				.collect(Collectors.summingLong(count -> count));
		if (logger.isDebugEnabled()) {
			execution = execution.doOnNext(rowsAffected -> logger.debug(rowsAffected + " returned as update count for SQL: " + statement));
		}
		return execution.onErrorResume(ex -> {
			boolean dropStatement = StringUtils.startsWithIgnoreCase(statement.trim(), "drop");
			if (continueOnError || (dropStatement && ignoreFailedDrops)) {
				// Swallow the failure per the configured policy, logging it for diagnosis.
				if (logger.isDebugEnabled()) {
					logger.debug(ScriptStatementFailedException.buildErrorMessage(statement, statementNumber.get(), resource),
							ex);
				}
			}
			else {
				return Mono.error(new ScriptStatementFailedException(statement, statementNumber.get(), resource, ex));
			}
			return Mono.empty();
		}).then();
	}

}

View File

@ -0,0 +1,47 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection.init;
/**
* Thrown when we cannot determine anything more specific than "something went wrong while
* processing an SQL script": for example, a {@link io.r2dbc.spi.R2dbcException} from
* R2DBC that we cannot pinpoint more precisely.
*
* @author Mark Paluch
* @since 5.3
*/
@SuppressWarnings("serial")
public class UncategorizedScriptException extends ScriptException {

	/**
	 * Create a new {@code UncategorizedScriptException}.
	 * @param message detailed message
	 */
	public UncategorizedScriptException(String message) {
		super(message);
	}

	/**
	 * Create a new {@code UncategorizedScriptException}.
	 * @param message detailed message
	 * @param cause the root cause of the script processing failure
	 */
	public UncategorizedScriptException(String message, Throwable cause) {
		super(message, cause);
	}

}

View File

@ -0,0 +1,6 @@
/**
 * Provides extensible support for initializing databases through SQL scripts,
 * executed over R2DBC {@link io.r2dbc.spi.Connection Connections}.
 */
@org.springframework.lang.NonNullApi
@org.springframework.lang.NonNullFields
package org.springframework.r2dbc.connection.init;

View File

@ -0,0 +1,243 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection.lookup;
import java.util.HashMap;
import java.util.Map;
import io.r2dbc.spi.Connection;
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.ConnectionFactoryMetadata;
import reactor.core.publisher.Mono;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
/**
* Abstract {@link ConnectionFactory} implementation that routes
* {@link #create()} calls to one of various target
* {@link ConnectionFactory factories} based on a lookup key.
* The latter is typically (but not necessarily) determined from some
* subscriber context.
*
* <p> Allows to configure a {@link #setDefaultTargetConnectionFactory(Object)
* default ConnectionFactory} as fallback.
*
* <p> Calls to {@link #getMetadata()} are routed to the
* {@link #setDefaultTargetConnectionFactory(Object) default ConnectionFactory}
* if configured.
*
* @author Mark Paluch
* @author Jens Schauder
* @since 5.3
* @see #setTargetConnectionFactories
* @see #setDefaultTargetConnectionFactory
* @see #determineCurrentLookupKey()
*/
public abstract class AbstractRoutingConnectionFactory implements ConnectionFactory, InitializingBean {

	// Internal marker object; not referenced in the code shown here — presumably
	// distinguishes the fallback case in lookup results (TODO confirm against
	// determineTargetConnectionFactory()).
	private static final Object FALLBACK_MARKER = new Object();

	// User-specified map: lookup key -> ConnectionFactory instance or name String;
	// resolved into resolvedConnectionFactories by afterPropertiesSet().
	@Nullable
	private Map<?, ?> targetConnectionFactories;

	// User-specified default target: a ConnectionFactory instance or a name String.
	@Nullable
	private Object defaultTargetConnectionFactory;

	// Whether unknown lookup keys fall back to the default factory (see setLenientFallback).
	private boolean lenientFallback = true;

	// Strategy for resolving connection factory name Strings into instances.
	private ConnectionFactoryLookup connectionFactoryLookup = new MapConnectionFactoryLookup();

	// Fully resolved lookup key -> ConnectionFactory map, built in afterPropertiesSet().
	@Nullable
	private Map<Object, ConnectionFactory> resolvedConnectionFactories;

	// Resolved default factory, if a default target was specified.
	@Nullable
	private ConnectionFactory resolvedDefaultConnectionFactory;
	/**
	 * Specify the map of target {@link ConnectionFactory ConnectionFactories},
	 * with the lookup key as key. The mapped value can either be a corresponding
	 * {@link ConnectionFactory} instance or a connection factory name String (to be
	 * resolved via a {@link #setConnectionFactoryLookup ConnectionFactoryLookup}).
	 *
	 * <p>The key can be of arbitrary type; this class implements the generic lookup
	 * process only. The concrete key representation will be handled by
	 * {@link #resolveSpecifiedLookupKey(Object)} and {@link #determineCurrentLookupKey()}.
	 *
	 * <p>The supplied map is resolved into actual {@link ConnectionFactory} instances
	 * in {@link #afterPropertiesSet()}.
	 */
	public void setTargetConnectionFactories(Map<?, ?> targetConnectionFactories) {
		this.targetConnectionFactories = targetConnectionFactories;
	}
	/**
	 * Specify the default target {@link ConnectionFactory}, if any.
	 *
	 * <p>The mapped value can either be a corresponding {@link ConnectionFactory}
	 * instance or a connection factory name {@link String} (to be resolved via a
	 * {@link #setConnectionFactoryLookup ConnectionFactoryLookup}).
	 *
	 * <p>This {@link ConnectionFactory} will be used as target if none of the keyed
	 * {@link #setTargetConnectionFactories targetConnectionFactories} match the
	 * {@link #determineCurrentLookupKey() current lookup key}. It is also the
	 * factory whose {@link ConnectionFactoryMetadata} is exposed by
	 * {@link #getMetadata()}.
	 */
	public void setDefaultTargetConnectionFactory(Object defaultTargetConnectionFactory) {
		this.defaultTargetConnectionFactory = defaultTargetConnectionFactory;
	}
	/**
	 * Specify whether to apply a lenient fallback to the default {@link ConnectionFactory}
	 * if no specific {@link ConnectionFactory} could be found for the current lookup key.
	 *
	 * <p>Default is {@code true}, accepting lookup keys without a corresponding entry
	 * in the target {@link ConnectionFactory} map - simply falling back to the default
	 * {@link ConnectionFactory} in that case.
	 *
	 * <p>Switch this flag to {@code false} if you would prefer the fallback to only
	 * apply when no lookup key was emitted. Lookup keys without a {@link ConnectionFactory}
	 * entry will then lead to an {@link IllegalStateException}.
	 * @param lenientFallback {@code true} to accept unknown lookup keys
	 * @see #setTargetConnectionFactories
	 * @see #setDefaultTargetConnectionFactory
	 * @see #determineCurrentLookupKey()
	 */
	public void setLenientFallback(boolean lenientFallback) {
		this.lenientFallback = lenientFallback;
	}
	/**
	 * Set the {@link ConnectionFactoryLookup} implementation to use for resolving
	 * connection factory name Strings in the {@link #setTargetConnectionFactories
	 * targetConnectionFactories} map.
	 * <p>Defaults to a {@link MapConnectionFactoryLookup}.
	 * @param connectionFactoryLookup the lookup strategy (must not be {@code null})
	 */
	public void setConnectionFactoryLookup(ConnectionFactoryLookup connectionFactoryLookup) {
		Assert.notNull(connectionFactoryLookup, "ConnectionFactoryLookup must not be null");
		this.connectionFactoryLookup = connectionFactoryLookup;
	}
@Override
public void afterPropertiesSet() {
Assert.notNull(this.targetConnectionFactories, "Property 'targetConnectionFactories' must not be null");
this.resolvedConnectionFactories = new HashMap<>(this.targetConnectionFactories.size());
this.targetConnectionFactories.forEach((key, value) -> {
Object lookupKey = resolveSpecifiedLookupKey(key);
ConnectionFactory connectionFactory = resolveSpecifiedConnectionFactory(value);
this.resolvedConnectionFactories.put(lookupKey, connectionFactory);
});
if (this.defaultTargetConnectionFactory != null) {
this.resolvedDefaultConnectionFactory = resolveSpecifiedConnectionFactory(this.defaultTargetConnectionFactory);
}
}
	/**
	 * Resolve the given lookup key object, as specified in the
	 * {@link #setTargetConnectionFactories targetConnectionFactories} map,
	 * into the actual lookup key to be used for matching with the
	 * {@link #determineCurrentLookupKey() current lookup key}.
	 * <p>The default implementation simply returns the given key as-is.
	 * @param lookupKey the lookup key object as specified by the user
	 * @return the lookup key as needed for matching.
	 * @see #afterPropertiesSet()
	 */
	protected Object resolveSpecifiedLookupKey(Object lookupKey) {
		return lookupKey;
	}
/**
 * Resolve the specified connection factory object into a
 * {@link ConnectionFactory} instance.
 * <p>The default implementation handles {@link ConnectionFactory} instances
 * and connection factory names (to be resolved via a
 * {@link #setConnectionFactoryLookup ConnectionFactoryLookup}).
 * @param connectionFactory the connection factory value object as specified in the
 * {@link #setTargetConnectionFactories targetConnectionFactories} map
 * @return the resolved {@link ConnectionFactory} (never {@code null})
 * @throws IllegalArgumentException in case of an unsupported value type
 */
protected ConnectionFactory resolveSpecifiedConnectionFactory(Object connectionFactory)
		throws IllegalArgumentException {
	// Already a ConnectionFactory: use it as-is.
	if (connectionFactory instanceof ConnectionFactory) {
		return (ConnectionFactory) connectionFactory;
	}
	// A String is treated as a lookup name and resolved through the configured lookup.
	if (connectionFactory instanceof String) {
		return this.connectionFactoryLookup.getConnectionFactory((String) connectionFactory);
	}
	throw new IllegalArgumentException(
			"Illegal connection factory value - only 'io.r2dbc.spi.ConnectionFactory' and 'String' supported: "
			+ connectionFactory);
}
@Override
public Mono<Connection> create() {
	// Route to the target factory for the current lookup key,
	// then defer to that factory for the actual connection.
	return determineTargetConnectionFactory()
			.flatMap(connectionFactory -> Mono.from(connectionFactory.create()));
}
@Override
public ConnectionFactoryMetadata getMetadata() {
	// Metadata can only be derived from the default target; routing is
	// asynchronous, so no per-subscriber target is available here.
	ConnectionFactory defaultFactory = this.resolvedDefaultConnectionFactory;
	if (defaultFactory == null) {
		throw new UnsupportedOperationException(
				"No default ConnectionFactory configured to retrieve ConnectionFactoryMetadata");
	}
	return defaultFactory.getMetadata();
}
/**
 * Retrieve the current target {@link ConnectionFactory}. Determines the
 * {@link #determineCurrentLookupKey() current lookup key}, performs a lookup
 * in the {@link #setTargetConnectionFactories targetConnectionFactories} map,
 * falls back to the specified {@link #setDefaultTargetConnectionFactory default
 * target ConnectionFactory} if necessary.
 * @return {@link Mono} emitting the current {@link ConnectionFactory} as
 * per {@link #determineCurrentLookupKey()}
 * @see #determineCurrentLookupKey()
 */
protected Mono<ConnectionFactory> determineTargetConnectionFactory() {
	Assert.state(this.resolvedConnectionFactories != null, "ConnectionFactory router not initialized");
	// FALLBACK_MARKER stands in for "no lookup key emitted".
	return determineCurrentLookupKey().defaultIfEmpty(FALLBACK_MARKER).handle((lookupKey, sink) -> {
		ConnectionFactory factory = this.resolvedConnectionFactories.get(lookupKey);
		if (factory == null && (lookupKey == FALLBACK_MARKER || this.lenientFallback)) {
			// No direct match: fall back to the default target when the key was
			// absent, or whenever lenient fallback is enabled.
			factory = this.resolvedDefaultConnectionFactory;
		}
		if (factory != null) {
			sink.next(factory);
		}
		else {
			sink.error(new IllegalStateException(String.format(
					"Cannot determine target ConnectionFactory for lookup key '%s'", lookupKey == FALLBACK_MARKER ? null : lookupKey)));
		}
	});
}
/**
 * Determine the current lookup key. This will typically be implemented to check a
 * subscriber context. Allows for arbitrary keys. The returned key needs to match the
 * stored lookup key type, as resolved by the {@link #resolveSpecifiedLookupKey} method.
 * @return {@link Mono} emitting the lookup key; may complete without emitting a value
 * if no lookup key is available
 */
protected abstract Mono<Object> determineCurrentLookupKey();
}

View File

@ -0,0 +1,85 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection.lookup;
import io.r2dbc.spi.ConnectionFactory;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryAware;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
/**
 * {@link ConnectionFactoryLookup} implementation based on a
 * Spring {@link BeanFactory}.
 *
 * <p>Looks up Spring-managed beans identified by bean name,
 * expecting them to be of type {@link ConnectionFactory}.
 *
 * @author Mark Paluch
 * @since 5.3
 * @see BeanFactory
 */
public class BeanFactoryConnectionFactoryLookup implements ConnectionFactoryLookup, BeanFactoryAware {

	@Nullable
	private BeanFactory beanFactory;


	/**
	 * Create a new instance of the {@link BeanFactoryConnectionFactoryLookup} class.
	 * <p>The {@link BeanFactory} to access must be set via {@code setBeanFactory}.
	 * @see #setBeanFactory
	 */
	public BeanFactoryConnectionFactoryLookup() {
	}

	/**
	 * Create a new instance of the {@link BeanFactoryConnectionFactoryLookup} class.
	 * <p>Use of this constructor is redundant if this object is being created
	 * by a Spring IoC container, as the supplied {@link BeanFactory} will be
	 * replaced by the {@link BeanFactory} that creates it (c.f. the
	 * {@link BeanFactoryAware} contract). So only use this constructor if you
	 * are using this class outside the context of a Spring IoC container.
	 * @param beanFactory the bean factory to be used to look up {@link ConnectionFactory
	 * ConnectionFactories} (must not be {@code null})
	 */
	public BeanFactoryConnectionFactoryLookup(BeanFactory beanFactory) {
		Assert.notNull(beanFactory, "BeanFactory must not be null");
		this.beanFactory = beanFactory;
	}


	@Override
	public void setBeanFactory(BeanFactory beanFactory) {
		this.beanFactory = beanFactory;
	}

	@Override
	public ConnectionFactory getConnectionFactory(String connectionFactoryName)
			throws ConnectionFactoryLookupFailureException {
		BeanFactory beanFactory = this.beanFactory;
		Assert.state(beanFactory != null, "BeanFactory is required");
		try {
			return beanFactory.getBean(connectionFactoryName, ConnectionFactory.class);
		}
		catch (BeansException ex) {
			// Translate any bean lookup failure into the lookup-specific exception.
			throw new ConnectionFactoryLookupFailureException(
					String.format("Failed to look up ConnectionFactory bean with name '%s'", connectionFactoryName), ex);
		}
	}

}

View File

@ -0,0 +1,38 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection.lookup;
import io.r2dbc.spi.ConnectionFactory;
/**
 * Strategy interface for looking up a {@link ConnectionFactory} by name.
 *
 * <p>Declared as a functional interface, so a lookup can also be
 * supplied as a lambda expression.
 *
 * @author Mark Paluch
 * @since 5.3
 */
@FunctionalInterface
public interface ConnectionFactoryLookup {

	/**
	 * Retrieve the {@link ConnectionFactory} identified by the given name.
	 * @param connectionFactoryName the name of the {@link ConnectionFactory}
	 * @return the {@link ConnectionFactory} (never {@code null})
	 * @throws ConnectionFactoryLookupFailureException if the lookup failed
	 */
	ConnectionFactory getConnectionFactory(String connectionFactoryName) throws ConnectionFactoryLookupFailureException;

}

View File

@ -0,0 +1,49 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection.lookup;
import org.springframework.dao.NonTransientDataAccessException;
/**
 * Exception to be thrown by a {@link ConnectionFactoryLookup} implementation,
 * indicating that the specified {@link io.r2dbc.spi.ConnectionFactory} could
 * not be obtained.
 *
 * <p>Extends {@link NonTransientDataAccessException} because a failed lookup
 * is not expected to succeed on retry without configuration changes.
 *
 * @author Mark Paluch
 * @since 5.3
 */
@SuppressWarnings("serial")
public class ConnectionFactoryLookupFailureException extends NonTransientDataAccessException {

	/**
	 * Create a new {@code ConnectionFactoryLookupFailureException}.
	 * @param msg the detail message
	 */
	public ConnectionFactoryLookupFailureException(String msg) {
		super(msg);
	}

	/**
	 * Create a new {@code ConnectionFactoryLookupFailureException}.
	 * @param msg the detail message
	 * @param cause the root cause (e.g. a {@code BeansException} from a bean-based lookup)
	 */
	public ConnectionFactoryLookupFailureException(String msg, Throwable cause) {
		super(msg, cause);
	}

}

View File

@ -0,0 +1,111 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection.lookup;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import io.r2dbc.spi.ConnectionFactory;
import org.springframework.util.Assert;
/**
 * Simple {@link ConnectionFactoryLookup} implementation that relies
 * on a map for doing lookups.
 *
 * <p>Useful for testing environments or applications that need to match
 * arbitrary {@link String} names to target {@link ConnectionFactory} objects.
 *
 * @author Mark Paluch
 * @author Jens Schauder
 * @since 5.3
 */
public class MapConnectionFactoryLookup implements ConnectionFactoryLookup {

	// Registered factories keyed by lookup name; populated via the constructors,
	// setConnectionFactories(Map) and addConnectionFactory(String, ConnectionFactory).
	private final Map<String, ConnectionFactory> connectionFactories = new HashMap<>();


	/**
	 * Create a new (empty) instance of the {@link MapConnectionFactoryLookup} class.
	 */
	public MapConnectionFactoryLookup() {}

	/**
	 * Create a new instance of the {@link MapConnectionFactoryLookup} class.
	 * @param connectionFactories the {@link Map} of {@link ConnectionFactory ConnectionFactories}.
	 * The keys are {@link String Strings}, the values are actual {@link ConnectionFactory} instances.
	 */
	public MapConnectionFactoryLookup(Map<String, ConnectionFactory> connectionFactories) {
		setConnectionFactories(connectionFactories);
	}

	/**
	 * Create a new instance of the {@link MapConnectionFactoryLookup} class
	 * containing a single mapping.
	 * @param connectionFactoryName the name under which the supplied {@link ConnectionFactory} is to be added
	 * @param connectionFactory the {@link ConnectionFactory} to be added
	 */
	public MapConnectionFactoryLookup(String connectionFactoryName, ConnectionFactory connectionFactory) {
		addConnectionFactory(connectionFactoryName, connectionFactory);
	}


	/**
	 * Set the {@link Map} of {@link ConnectionFactory ConnectionFactories}.
	 * The keys are {@link String Strings}, the values are actual {@link ConnectionFactory} instances.
	 * <p>The supplied entries are added to (and may override entries in) the
	 * map of factories already held by this lookup.
	 * @param connectionFactories said {@link Map} of {@link ConnectionFactory connectionFactories};
	 * must not be {@code null}
	 * @throws IllegalArgumentException if {@code connectionFactories} is {@code null}
	 */
	public void setConnectionFactories(Map<String, ConnectionFactory> connectionFactories) {
		Assert.notNull(connectionFactories, "ConnectionFactories must not be null");
		this.connectionFactories.putAll(connectionFactories);
	}

	/**
	 * Get the {@link Map} of {@link ConnectionFactory ConnectionFactories} maintained by this object.
	 * <p>The returned {@link Map} is {@link Collections#unmodifiableMap(Map) unmodifiable}.
	 * @return {@link Map} of {@link ConnectionFactory connectionFactory} (never {@code null})
	 */
	public Map<String, ConnectionFactory> getConnectionFactories() {
		return Collections.unmodifiableMap(this.connectionFactories);
	}

	/**
	 * Add the supplied {@link ConnectionFactory} to the map of
	 * {@link ConnectionFactory ConnectionFactories} maintained by this object.
	 * @param connectionFactoryName the name under which the supplied {@link ConnectionFactory} is to be added
	 * @param connectionFactory the {@link ConnectionFactory} to be so added
	 */
	public void addConnectionFactory(String connectionFactoryName, ConnectionFactory connectionFactory) {
		Assert.notNull(connectionFactoryName, "ConnectionFactory name must not be null");
		Assert.notNull(connectionFactory, "ConnectionFactory must not be null");
		this.connectionFactories.put(connectionFactoryName, connectionFactory);
	}

	@Override
	public ConnectionFactory getConnectionFactory(String connectionFactoryName)
			throws ConnectionFactoryLookupFailureException {
		Assert.notNull(connectionFactoryName, "ConnectionFactory name must not be null");
		// Plain read: use get(..) rather than computeIfAbsent(..) since this
		// method must never register a mapping.
		ConnectionFactory connectionFactory = this.connectionFactories.get(connectionFactoryName);
		if (connectionFactory == null) {
			throw new ConnectionFactoryLookupFailureException(
					"No ConnectionFactory with name '" + connectionFactoryName + "' registered");
		}
		return connectionFactory;
	}

}

View File

@ -0,0 +1,52 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection.lookup;
import io.r2dbc.spi.ConnectionFactory;
import org.springframework.util.Assert;
/**
 * An implementation of {@link ConnectionFactoryLookup} that
 * simply wraps a single given {@link ConnectionFactory}
 * returned for any connection factory name.
 *
 * @author Mark Paluch
 * @since 5.3
 */
public class SingleConnectionFactoryLookup implements ConnectionFactoryLookup {

	// The sole factory served for every lookup, regardless of name.
	private final ConnectionFactory connectionFactory;

	/**
	 * Create a new instance of the {@link SingleConnectionFactoryLookup} class.
	 * @param connectionFactory the single {@link ConnectionFactory} to wrap
	 * (must not be {@code null})
	 */
	public SingleConnectionFactoryLookup(ConnectionFactory connectionFactory) {
		Assert.notNull(connectionFactory, "ConnectionFactory must not be null");
		this.connectionFactory = connectionFactory;
	}

	@Override
	public ConnectionFactory getConnectionFactory(String connectionFactoryName)
			throws ConnectionFactoryLookupFailureException {
		// The supplied name is intentionally ignored.
		return this.connectionFactory;
	}

}

View File

@ -0,0 +1,9 @@
/**
* Provides a strategy for looking up R2DBC ConnectionFactories by name.
*/
@NonNullApi
@NonNullFields
package org.springframework.r2dbc.connection.lookup;
import org.springframework.lang.NonNullApi;
import org.springframework.lang.NonNullFields;

View File

@ -0,0 +1,11 @@
/**
* Provides a utility class for easy ConnectionFactory access,
* a ReactiveTransactionManager for a single ConnectionFactory,
* and various simple ConnectionFactory implementations.
*/
@NonNullApi
@NonNullFields
package org.springframework.r2dbc.connection;
import org.springframework.lang.NonNullApi;
import org.springframework.lang.NonNullFields;

View File

@ -0,0 +1,72 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core;
import org.springframework.lang.Nullable;
/**
 * Interface that defines common functionality for objects
 * that can offer parameter values for named bind parameters,
 * serving as argument for {@link NamedParameterExpander} operations.
 *
 * <p>This interface allows for the specification of the type in
 * addition to parameter values. All parameter values and types are
 * identified by specifying the name of the parameter.
 *
 * <p>Intended to wrap various implementations like a {@link java.util.Map}
 * with a consistent interface.
 *
 * @author Mark Paluch
 * @since 5.3
 * @see MapBindParameterSource
 */
interface BindParameterSource {

	/**
	 * Determine whether there is a value for the specified named parameter.
	 * @param paramName the name of the parameter
	 * @return {@code true} if there is a value defined; {@code false} otherwise
	 */
	boolean hasValue(String paramName);

	/**
	 * Return the parameter value for the requested named parameter.
	 * @param paramName the name of the parameter
	 * @return the value of the specified parameter, can be {@code null}
	 * @throws IllegalArgumentException if there is no value
	 * for the requested parameter
	 */
	@Nullable
	Object getValue(String paramName) throws IllegalArgumentException;

	/**
	 * Determine the type for the specified named parameter.
	 * @param paramName the name of the parameter
	 * @return the type of the specified parameter, or
	 * {@code Object.class} if not known
	 */
	default Class<?> getType(String paramName) {
		return Object.class;
	}

	/**
	 * Return parameter names of the underlying parameter source.
	 * @return parameter names of the underlying parameter source
	 */
	Iterable<String> getParameterNames();

}

View File

@ -0,0 +1,102 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core;
import java.util.Collection;
import java.util.Map;
import java.util.function.BiFunction;
import io.r2dbc.spi.ColumnMetadata;
import io.r2dbc.spi.Row;
import io.r2dbc.spi.RowMetadata;
import org.springframework.lang.Nullable;
import org.springframework.util.LinkedCaseInsensitiveMap;
/**
 * {@link BiFunction Mapping function} implementation that creates a
 * {@code java.util.Map} for each row, representing all columns as
 * key-value pairs: one entry for each column, with the column name as key.
 *
 * <p>The Map implementation to use and the key to use for each column
 * in the column Map can be customized through overriding
 * {@link #createColumnMap} and {@link #getColumnKey}, respectively.
 *
 * <p><b>Note:</b> By default, {@code ColumnMapRowMapper} builds a linked Map
 * with case-insensitive keys (a {@link LinkedCaseInsensitiveMap}), which
 * preserves column order and allows any casing to be used for column names.
 *
 * @author Mark Paluch
 * @since 5.3
 */
public class ColumnMapRowMapper implements BiFunction<Row, RowMetadata, Map<String, Object>> {

	/** Default, shared instance (this mapper is stateless). */
	public static final ColumnMapRowMapper INSTANCE = new ColumnMapRowMapper();


	@Override
	public Map<String, Object> apply(Row row, RowMetadata rowMetadata) {
		Collection<String> columns = rowMetadata.getColumnNames();
		Map<String, Object> mapOfColumnValues = createColumnMap(columns.size());
		int index = 0;
		for (String column : columns) {
			// Values are read by positional index while keys go through
			// getColumnKey(..) so subclasses can customize key derivation.
			mapOfColumnValues.put(getColumnKey(column), getColumnValue(row, index++));
		}
		return mapOfColumnValues;
	}

	/**
	 * Create a {@link Map} instance to be used as column map.
	 * <p>By default, a linked case-insensitive Map will be created.
	 * @param columnCount the column count, to be used as initial capacity for the Map
	 * @return the new {@link Map} instance
	 * @see LinkedCaseInsensitiveMap
	 */
	protected Map<String, Object> createColumnMap(int columnCount) {
		return new LinkedCaseInsensitiveMap<>(columnCount);
	}

	/**
	 * Determine the key to use for the given column in the column {@link Map}.
	 * <p>The default implementation returns the column name unchanged.
	 * @param columnName the column name as returned by the {@link Row}
	 * @return the column key to use
	 * @see ColumnMetadata#getName()
	 */
	protected String getColumnKey(String columnName) {
		return columnName;
	}

	/**
	 * Retrieve a R2DBC object value for the specified column.
	 * <p>The default implementation uses the {@link Row#get(int)} method.
	 * @param row is the {@link Row} holding the data
	 * @param index is the column index
	 * @return the Object returned, possibly {@code null} for SQL NULL
	 */
	@Nullable
	protected Object getColumnValue(Row row, int index) {
		return row.get(index);
	}

}

View File

@ -0,0 +1,66 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core;
import java.util.function.Function;
import io.r2dbc.spi.Connection;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.dao.DataAccessException;
/**
 * Interface declaring methods that accept a callback {@link Function}
 * to operate within the scope of a {@link Connection}.
 * Callback functions operate on a provided connection and must not
 * close the connection as the connections may be pooled or be
 * subject to other kinds of resource management.
 *
 * <p>Callback functions are responsible for creating a
 * {@link org.reactivestreams.Publisher} that defines the scope of how
 * long the allocated {@link Connection} is valid. Connections are
 * released after the publisher terminates.
 *
 * @author Mark Paluch
 * @since 5.3
 */
public interface ConnectionAccessor {

	/**
	 * Execute a callback {@link Function} within a {@link Connection} scope.
	 * The function is responsible for creating a {@link Mono}. The connection
	 * is released after the {@link Mono} terminates (or the subscription
	 * is cancelled). Connection resources must not be passed outside of the
	 * {@link Function} closure, otherwise resources may get defunct.
	 * @param action the callback object that specifies the connection action
	 * @return the resulting {@link Mono} emitting at most one value
	 */
	<T> Mono<T> inConnection(Function<Connection, Mono<T>> action) throws DataAccessException;

	/**
	 * Execute a callback {@link Function} within a {@link Connection} scope.
	 * The function is responsible for creating a {@link Flux}. The connection
	 * is released after the {@link Flux} terminates (or the subscription
	 * is cancelled). Connection resources must not be passed outside of the
	 * {@link Function} closure, otherwise resources may get defunct.
	 * @param action the callback object that specifies the connection action
	 * @return the resulting {@link Flux} emitting any number of values
	 */
	<T> Flux<T> inConnectionMany(Function<Connection, Flux<T>> action) throws DataAccessException;

}

View File

@ -0,0 +1,53 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core;
import java.util.function.Function;
import io.r2dbc.spi.Connection;
/**
 * Union type combining {@link Function} and {@link SqlProvider} to expose
 * the SQL that is related to the underlying action.
 *
 * <p>Delegates {@link #apply(Connection)} to the wrapped function while
 * reporting the associated SQL via {@link #getSql()}.
 *
 * @author Mark Paluch
 * @since 5.3
 * @param <R> the type of the result of the function
 */
class ConnectionFunction<R> implements Function<Connection, R>, SqlProvider {

	private final String sql;

	private final Function<Connection, R> delegate;


	ConnectionFunction(String sql, Function<Connection, R> delegate) {
		this.sql = sql;
		this.delegate = delegate;
	}


	@Override
	public R apply(Connection connection) {
		return this.delegate.apply(connection);
	}

	@Override
	public String getSql() {
		return this.sql;
	}

}

View File

@ -0,0 +1,250 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.Row;
import io.r2dbc.spi.RowMetadata;
import io.r2dbc.spi.Statement;
import reactor.core.publisher.Mono;
import org.springframework.r2dbc.core.binding.BindMarkersFactory;
import org.springframework.util.Assert;
/**
* A non-blocking, reactive client for performing database calls requests with
* Reactive Streams back pressure. Provides a higher level, common API over
* R2DBC client libraries.
*
* <p>Use one of the static factory methods {@link #create(ConnectionFactory)}
* or obtain a {@link DatabaseClient#builder()} to create an instance.
*
* Usage example:
* <p><pre class="code">
* ConnectionFactory factory =
*
* DatabaseClient client = DatabaseClient.create(factory);
 * Mono&lt;Actor&gt; actor = client.sql("select first_name, last_name from t_actor")
 *     .map(row -> new Actor(row.get("first_name", String.class),
 *         row.get("last_name", String.class)))
* .first();
* </pre>
*
* @author Mark Paluch
* @since 5.3
*/
public interface DatabaseClient extends ConnectionAccessor {
/**
* Specify a static {@code sql} statement to run. Contract for specifying a
* SQL call along with options leading to the execution. The SQL string can
* contain either native parameter bind markers or named parameters (e.g.
* {@literal :foo, :bar}) when {@link NamedParameterExpander} is enabled.
*
* @param sql must not be {@code null} or empty
* @return a new {@link GenericExecuteSpec}
* @see NamedParameterExpander
* @see DatabaseClient.Builder#namedParameters(boolean)
*/
GenericExecuteSpec sql(String sql);
/**
* Specify a {@link Supplier SQL supplier} that provides SQL to run.
* Contract for specifying a SQL call along with options leading to
* the execution. The SQL string can contain either native parameter
* bind markers or named parameters (e.g. {@literal :foo, :bar}) when
* {@link NamedParameterExpander} is enabled.
*
* <p>Accepts {@link PreparedOperation} as SQL and binding {@link Supplier}
* @param sqlSupplier must not be {@code null}
* @return a new {@link GenericExecuteSpec}
* @see NamedParameterExpander
* @see DatabaseClient.Builder#namedParameters(boolean)
* @see PreparedOperation
*/
GenericExecuteSpec sql(Supplier<String> sqlSupplier);
// Static, factory methods
/**
* Create a {@code DatabaseClient} that will use the provided {@link ConnectionFactory}.
* @param factory the {@code ConnectionFactory} to use for obtaining connections
* @return a new {@code DatabaseClient}. Guaranteed to be not {@code null}.
*/
static DatabaseClient create(ConnectionFactory factory) {
return new DefaultDatabaseClientBuilder().connectionFactory(factory).build();
}
/**
* Obtain a {@code DatabaseClient} builder.
*/
static DatabaseClient.Builder builder() {
return new DefaultDatabaseClientBuilder();
}
/**
* A mutable builder for creating a {@link DatabaseClient}.
*/
interface Builder {
/**
* Configure the {@link BindMarkersFactory BindMarkers} to be used.
* @param bindMarkers must not be {@code null}
*/
Builder bindMarkers(BindMarkersFactory bindMarkers);
/**
* Configure the {@link ConnectionFactory R2DBC connector}.
* @param factory must not be {@code null}
*/
Builder connectionFactory(ConnectionFactory factory);
/**
* Configure a {@link ExecuteFunction} to execute {@link Statement} objects.
* @param executeFunction must not be {@code null}
* @see Statement#execute()
*/
Builder executeFunction(ExecuteFunction executeFunction);
/**
* Configure whether to use named parameter expansion. Defaults to {@code true}.
* @param enabled {@code true} to use named parameter expansion.
* {@code false} to disable named parameter expansion.
* @see NamedParameterExpander
*/
Builder namedParameters(boolean enabled);
/**
* Configures a {@link Consumer} to configure this builder.
* @param builderConsumer must not be {@code null}.
*/
Builder apply(Consumer<Builder> builderConsumer);
/**
 * Build the {@link DatabaseClient} instance.
*/
DatabaseClient build();
}
	/**
	 * Contract for specifying a SQL call along with options leading to the execution.
	 */
	interface GenericExecuteSpec {

		/**
		 * Bind a non-{@code null} value to a parameter identified by its
		 * {@code index}. {@code value} can be either a scalar value or {@link Parameter}.
		 * @param index zero based index to bind the parameter to
		 * @param value must not be {@code null}. Can be either a scalar value or {@link Parameter}
		 * @return a new {@code GenericExecuteSpec} carrying the binding
		 */
		GenericExecuteSpec bind(int index, Object value);

		/**
		 * Bind a {@code null} value to a parameter identified by its {@code index}.
		 * @param index zero based index to bind the parameter to
		 * @param type must not be {@code null}
		 * @return a new {@code GenericExecuteSpec} carrying the binding
		 */
		GenericExecuteSpec bindNull(int index, Class<?> type);

		/**
		 * Bind a non-{@code null} value to a parameter identified by its {@code name}.
		 * @param name must not be {@code null} or empty
		 * @param value must not be {@code null}
		 * @return a new {@code GenericExecuteSpec} carrying the binding
		 */
		GenericExecuteSpec bind(String name, Object value);

		/**
		 * Bind a {@code null} value to a parameter identified by its {@code name}.
		 * @param name must not be {@code null} or empty
		 * @param type must not be {@code null}
		 * @return a new {@code GenericExecuteSpec} carrying the binding
		 */
		GenericExecuteSpec bindNull(String name, Class<?> type);

		/**
		 * Add the given filter to the end of the filter chain.
		 * <p>Filter functions are typically used to invoke methods on the Statement
		 * before it is executed.
		 *
		 * For example:
		 * <p><pre class="code">
		 * DatabaseClient client = &#46;&#46;&#46;;
		 * client.sql("SELECT book_id FROM book").filter(statement -&gt; statement.fetchSize(100))
		 * </pre>
		 * @param filter the filter to be added to the chain
		 * @return a new {@code GenericExecuteSpec} with the filter appended
		 */
		default GenericExecuteSpec filter(Function<? super Statement, ? extends Statement> filter) {
			Assert.notNull(filter, "Statement FilterFunction must not be null");
			// Adapt the Statement-only function to the (statement, next) filter contract.
			return filter((statement, next) -> next.execute(filter.apply(statement)));
		}

		/**
		 * Add the given filter to the end of the filter chain.
		 * <p>Filter functions are typically used to invoke methods on the Statement
		 * before it is executed.
		 *
		 * For example:
		 * <p><pre class="code">
		 * DatabaseClient client = &#46;&#46;&#46;;
		 * client.sql("SELECT book_id FROM book").filter((statement, next) -&gt; next.execute(statement.fetchSize(100)))
		 * </pre>
		 * @param filter the filter to be added to the chain
		 * @return a new {@code GenericExecuteSpec} with the filter appended
		 */
		GenericExecuteSpec filter(StatementFilterFunction filter);

		/**
		 * Configure a result mapping {@link Function function} and enter the execution stage.
		 * @param mappingFunction must not be {@code null}
		 * @param <R> result type.
		 * @return a {@link FetchSpec} for configuration what to fetch. Guaranteed to be not {@code null}.
		 */
		default <R> RowsFetchSpec<R> map(Function<Row, R> mappingFunction) {
			Assert.notNull(mappingFunction, "Mapping function must not be null");
			// Delegate to the BiFunction variant, ignoring the RowMetadata.
			return map((row, rowMetadata) -> mappingFunction.apply(row));
		}

		/**
		 * Configure a result mapping {@link BiFunction function} and enter the execution stage.
		 * @param mappingFunction must not be {@code null}
		 * @param <R> result type.
		 * @return a {@link FetchSpec} for configuration what to fetch. Guaranteed to be not {@code null}.
		 */
		<R> RowsFetchSpec<R> map(BiFunction<Row, RowMetadata, R> mappingFunction);

		/**
		 * Perform the SQL call and retrieve the result by entering the execution stage.
		 * @return a {@link FetchSpec} mapping each row to a column-name/value map
		 */
		FetchSpec<Map<String, Object>> fetch();

		/**
		 * Perform the SQL call and return a {@link Mono} that completes without result on statement completion.
		 * @return a {@link Mono} ignoring its payload (actively dropping).
		 */
		Mono<Void> then();
	}
}

View File

@ -0,0 +1,603 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import io.r2dbc.spi.Connection;
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.R2dbcException;
import io.r2dbc.spi.Result;
import io.r2dbc.spi.Row;
import io.r2dbc.spi.RowMetadata;
import io.r2dbc.spi.Statement;
import io.r2dbc.spi.Wrapped;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.dao.DataAccessException;
import org.springframework.dao.InvalidDataAccessApiUsageException;
import org.springframework.lang.Nullable;
import org.springframework.r2dbc.connection.ConnectionFactoryUtils;
import org.springframework.r2dbc.core.binding.BindMarkersFactory;
import org.springframework.r2dbc.core.binding.BindTarget;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
/**
* Default implementation of {@link DatabaseClient}.
*
* @author Mark Paluch
* @author Mingyuan Wu
* @author Bogdan Ilchyshyn
* @since 5.3
*/
class DefaultDatabaseClient implements DatabaseClient {

	private final Log logger = LogFactory.getLog(getClass());

	private final BindMarkersFactory bindMarkersFactory;

	private final ConnectionFactory connectionFactory;

	private final ExecuteFunction executeFunction;

	private final boolean namedParameters;

	// Non-null if and only if namedParameters is true (see constructor).
	@Nullable
	private final NamedParameterExpander namedParameterExpander;


	DefaultDatabaseClient(BindMarkersFactory bindMarkersFactory,
			ConnectionFactory connectionFactory, ExecuteFunction executeFunction,
			boolean namedParameters) {
		this.bindMarkersFactory = bindMarkersFactory;
		this.connectionFactory = connectionFactory;
		this.executeFunction = executeFunction;
		this.namedParameters = namedParameters;
		this.namedParameterExpander = namedParameters ? new NamedParameterExpander()
				: null;
	}


	@Override
	public GenericExecuteSpec sql(String sql) {
		Assert.hasText(sql, "SQL must not be null or empty");
		return sql(() -> sql);
	}

	@Override
	public GenericExecuteSpec sql(Supplier<String> sqlSupplier) {
		Assert.notNull(sqlSupplier, "SQL Supplier must not be null");
		return new DefaultGenericExecuteSpec(sqlSupplier);
	}

	@Override
	public <T> Mono<T> inConnection(Function<Connection, Mono<T>> action)
			throws DataAccessException {
		Assert.notNull(action, "Callback object must not be null");
		// Wrap the connection in a holder so close is invoked at most once,
		// regardless of which terminal signal (complete/error/cancel) fires.
		Mono<ConnectionCloseHolder> connectionMono = getConnection().map(
				connection -> new ConnectionCloseHolder(connection, this::closeConnection));

		return Mono.usingWhen(connectionMono, connectionCloseHolder -> {
			// Create close-suppressing Connection proxy
			Connection connectionToUse = createConnectionProxy(connectionCloseHolder.connection);
			try {
				return action.apply(connectionToUse);
			}
			catch (R2dbcException ex) {
				// Translate exceptions thrown synchronously by the callback.
				String sql = getSql(action);
				return Mono.error(ConnectionFactoryUtils.convertR2dbcException("doInConnection", sql, ex));
			}
		}, ConnectionCloseHolder::close, (it, err) -> it.close(),
				ConnectionCloseHolder::close)
				// Translate exceptions emitted through the reactive pipeline.
				.onErrorMap(R2dbcException.class,
						ex -> ConnectionFactoryUtils.convertR2dbcException("execute", getSql(action), ex));
	}

	@Override
	public <T> Flux<T> inConnectionMany(Function<Connection, Flux<T>> action)
			throws DataAccessException {
		Assert.notNull(action, "Callback object must not be null");
		Mono<ConnectionCloseHolder> connectionMono = getConnection().map(
				connection -> new ConnectionCloseHolder(connection, this::closeConnection));

		return Flux.usingWhen(connectionMono, connectionCloseHolder -> {
			// Create close-suppressing Connection proxy, also preparing returned
			// Statements.
			Connection connectionToUse = createConnectionProxy(connectionCloseHolder.connection);
			try {
				return action.apply(connectionToUse);
			}
			catch (R2dbcException ex) {
				String sql = getSql(action);
				return Flux.error(ConnectionFactoryUtils.convertR2dbcException("doInConnectionMany", sql, ex));
			}
		}, ConnectionCloseHolder::close, (it, err) -> it.close(),
				ConnectionCloseHolder::close)
				.onErrorMap(R2dbcException.class,
						ex -> ConnectionFactoryUtils.convertR2dbcException("executeMany", getSql(action), ex));
	}

	/**
	 * Obtain a {@link Connection}.
	 * @return a {@link Mono} able to emit a {@link Connection}
	 */
	private Mono<Connection> getConnection() {
		return ConnectionFactoryUtils.getConnection(obtainConnectionFactory());
	}

	/**
	 * Release the {@link Connection}.
	 * <p>If a connection factory is bound to the current subscriber context
	 * (i.e. a transaction is in progress), the connection is left open for the
	 * transaction manager to release; otherwise it is closed here.
	 * @param connection to close.
	 * @return a {@link Publisher} that completes successfully when the connection is
	 * closed
	 */
	private Publisher<Void> closeConnection(Connection connection) {
		return ConnectionFactoryUtils.currentConnectionFactory(
				obtainConnectionFactory()).then().onErrorResume(Exception.class,
						e -> Mono.from(connection.close()));
	}

	/**
	 * Obtain the {@link ConnectionFactory} for actual use.
	 * @return the ConnectionFactory (never {@code null})
	 */
	private ConnectionFactory obtainConnectionFactory() {
		return this.connectionFactory;
	}

	/**
	 * Create a close-suppressing proxy for the given R2DBC
	 * Connection. Called by the {@code execute} method.
	 * @param con the R2DBC Connection to create a proxy for
	 * @return the Connection proxy
	 */
	private static Connection createConnectionProxy(Connection con) {
		return (Connection) Proxy.newProxyInstance(DatabaseClient.class.getClassLoader(),
				new Class<?>[] { Connection.class, Wrapped.class },
				new CloseSuppressingInvocationHandler(con));
	}

	/**
	 * Sum the affected-row counts of all results emitted for the given connection.
	 * @param resultFunction function producing the {@link Result} stream
	 * @param it the connection to execute against
	 * @return the total number of updated rows
	 */
	private static Mono<Integer> sumRowsUpdated(
			Function<Connection, Flux<Result>> resultFunction, Connection it) {
		return resultFunction.apply(it)
				.flatMap(Result::getRowsUpdated)
				.collect(Collectors.summingInt(Integer::intValue));
	}

	/**
	 * Determine SQL from potential provider object.
	 * @param sqlProvider object that's potentially a SqlProvider
	 * @return the SQL string, or {@code null}
	 * @see SqlProvider
	 */
	@Nullable
	private static String getSql(Object sqlProvider) {
		if (sqlProvider instanceof SqlProvider) {
			return ((SqlProvider) sqlProvider).getSql();
		}
		else {
			return null;
		}
	}


	/**
	 * Base class for {@link DatabaseClient.GenericExecuteSpec} implementations.
	 * <p>Instances are immutable: every {@code bind…}/{@code filter} call returns
	 * a new spec carrying a copy of the accumulated bindings.
	 */
	class DefaultGenericExecuteSpec implements GenericExecuteSpec {

		// Bindings by zero-based parameter index.
		final Map<Integer, Parameter> byIndex;

		// Bindings by parameter name.
		final Map<String, Parameter> byName;

		final Supplier<String> sqlSupplier;

		final StatementFilterFunction filterFunction;

		DefaultGenericExecuteSpec(Supplier<String> sqlSupplier) {
			this.byIndex = Collections.emptyMap();
			this.byName = Collections.emptyMap();
			this.sqlSupplier = sqlSupplier;
			this.filterFunction = StatementFilterFunctions.empty();
		}

		DefaultGenericExecuteSpec(Map<Integer, Parameter> byIndex, Map<String, Parameter> byName,
				Supplier<String> sqlSupplier, StatementFilterFunction filterFunction) {
			this.byIndex = byIndex;
			this.byName = byName;
			this.sqlSupplier = sqlSupplier;
			this.filterFunction = filterFunction;
		}

		@Override
		public DefaultGenericExecuteSpec bind(int index, Object value) {
			assertNotPreparedOperation();
			Assert.notNull(value, () -> String.format(
					"Value at index %d must not be null. Use bindNull(…) instead.",
					index));

			// Copy-on-write: keep this spec immutable.
			Map<Integer, Parameter> byIndex = new LinkedHashMap<>(this.byIndex);
			if (value instanceof Parameter) {
				byIndex.put(index, (Parameter) value);
			}
			else {
				byIndex.put(index, Parameter.fromOrEmpty(value, value.getClass()));
			}

			return new DefaultGenericExecuteSpec(byIndex, this.byName, this.sqlSupplier, this.filterFunction);
		}

		@Override
		public DefaultGenericExecuteSpec bindNull(int index, Class<?> type) {
			assertNotPreparedOperation();

			Map<Integer, Parameter> byIndex = new LinkedHashMap<>(this.byIndex);
			byIndex.put(index, Parameter.empty(type));

			return new DefaultGenericExecuteSpec(byIndex, this.byName, this.sqlSupplier, this.filterFunction);
		}

		@Override
		public DefaultGenericExecuteSpec bind(String name, Object value) {
			assertNotPreparedOperation();

			Assert.hasText(name, "Parameter name must not be null or empty!");
			Assert.notNull(value, () -> String.format(
					"Value for parameter %s must not be null. Use bindNull(…) instead.",
					name));

			Map<String, Parameter> byName = new LinkedHashMap<>(this.byName);
			if (value instanceof Parameter) {
				byName.put(name, (Parameter) value);
			}
			else {
				byName.put(name, Parameter.fromOrEmpty(value, value.getClass()));
			}

			return new DefaultGenericExecuteSpec(this.byIndex, byName, this.sqlSupplier, this.filterFunction);
		}

		@Override
		public DefaultGenericExecuteSpec bindNull(String name, Class<?> type) {
			assertNotPreparedOperation();
			Assert.hasText(name, "Parameter name must not be null or empty!");

			Map<String, Parameter> byName = new LinkedHashMap<>(this.byName);
			byName.put(name, Parameter.empty(type));

			return new DefaultGenericExecuteSpec(this.byIndex, byName, this.sqlSupplier, this.filterFunction);
		}

		@Override
		public DefaultGenericExecuteSpec filter(StatementFilterFunction filter) {
			Assert.notNull(filter, "Statement FilterFunction must not be null");
			return new DefaultGenericExecuteSpec(this.byIndex, this.byName, this.sqlSupplier, this.filterFunction.andThen(filter));
		}

		@Override
		public <R> FetchSpec<R> map(BiFunction<Row, RowMetadata, R> mappingFunction) {
			Assert.notNull(mappingFunction, "Mapping function must not be null");
			return execute(this.sqlSupplier, mappingFunction);
		}

		@Override
		public FetchSpec<Map<String, Object>> fetch() {
			return execute(this.sqlSupplier, ColumnMapRowMapper.INSTANCE);
		}

		@Override
		public Mono<Void> then() {
			return fetch().rowsUpdated().then();
		}

		/**
		 * Assemble the statement-creation and result functions and enter the
		 * fetch stage. Handles three statement-creation paths: pre-bound
		 * {@link PreparedOperation} suppliers, named-parameter expansion, and
		 * plain SQL with direct index/name bindings.
		 * @param sqlSupplier supplier of the SQL to execute
		 * @param mappingFunction row mapping function
		 * @return a {@link FetchSpec} for result retrieval
		 */
		private <T> FetchSpec<T> execute(Supplier<String> sqlSupplier,
				BiFunction<Row, RowMetadata, T> mappingFunction) {
			String sql = getRequiredSql(sqlSupplier);

			Function<Connection, Statement> statementFunction = connection -> {
				if (logger.isDebugEnabled()) {
					logger.debug("Executing SQL statement [" + sql + "]");
				}

				// Path 1: the supplier is itself a PreparedOperation carrying its bindings.
				if (sqlSupplier instanceof PreparedOperation<?>) {
					Statement statement = connection.createStatement(sql);
					BindTarget bindTarget = new StatementWrapper(statement);
					((PreparedOperation<?>) sqlSupplier).bindTo(bindTarget);
					return statement;
				}

				// Path 2: expand named parameters to native bind markers.
				// namedParameterExpander is non-null whenever namedParameters is true.
				if (DefaultDatabaseClient.this.namedParameters) {
					Map<String, Parameter> remainderByName = new LinkedHashMap<>(
							this.byName);
					Map<Integer, Parameter> remainderByIndex = new LinkedHashMap<>(
							this.byIndex);
					MapBindParameterSource namedBindings = retrieveParameters(sql,
							remainderByName, remainderByIndex);

					PreparedOperation<String> operation = DefaultDatabaseClient.this.namedParameterExpander.expand(sql,
							DefaultDatabaseClient.this.bindMarkersFactory, namedBindings);

					String expanded = getRequiredSql(operation);
					if (logger.isTraceEnabled()) {
						logger.trace("Expanded SQL [" + expanded + "]");
					}

					Statement statement = connection.createStatement(expanded);
					BindTarget bindTarget = new StatementWrapper(statement);

					operation.bindTo(bindTarget);

					// Bind any leftovers that were not consumed by named expansion.
					bindByName(statement, remainderByName);
					bindByIndex(statement, remainderByIndex);

					return statement;
				}

				// Path 3: plain SQL with direct bindings.
				Statement statement = connection.createStatement(sql);
				bindByIndex(statement, this.byIndex);
				bindByName(statement, this.byName);
				return statement;
			};

			Function<Connection, Flux<Result>> resultFunction = connection -> {
				Statement statement = statementFunction.apply(connection);
				return Flux.from(this.filterFunction.filter(statement, DefaultDatabaseClient.this.executeFunction))
						.cast(Result.class).checkpoint("SQL \"" + sql + "\" [DatabaseClient]");
			};

			return new DefaultFetchSpec<>(
					DefaultDatabaseClient.this, sql,
					new ConnectionFunction<>(sql, resultFunction),
					new ConnectionFunction<>(sql, connection -> sumRowsUpdated(resultFunction, connection)),
					mappingFunction);
		}

		/**
		 * Collect the bindings for each named parameter in the given SQL,
		 * removing consumed entries from the remainder maps.
		 * @throws InvalidDataAccessApiUsageException if a named parameter has no binding
		 */
		private MapBindParameterSource retrieveParameters(String sql,
				Map<String, Parameter> remainderByName,
				Map<Integer, Parameter> remainderByIndex) {

			List<String> parameterNames = DefaultDatabaseClient.this.namedParameterExpander.getParameterNames(sql);
			Map<String, Parameter> namedBindings = new LinkedHashMap<>(
					parameterNames.size());
			for (String parameterName : parameterNames) {
				Parameter parameter = getParameter(remainderByName, remainderByIndex,
						parameterNames, parameterName);
				if (parameter == null) {
					throw new InvalidDataAccessApiUsageException(
							String.format("No parameter specified for [%s] in query [%s]",
									parameterName, sql));
				}
				namedBindings.put(parameterName, parameter);
			}
			return new MapBindParameterSource(namedBindings);
		}

		/**
		 * Look up the binding for a named parameter, by name first, then by
		 * its positional index among the parameter names. Removes the match
		 * from the corresponding remainder map.
		 * @return the bound {@link Parameter}, or {@code null} if none found
		 */
		@Nullable
		private Parameter getParameter(Map<String, Parameter> remainderByName,
				Map<Integer, Parameter> remainderByIndex, List<String> parameterNames,
				String parameterName) {

			if (this.byName.containsKey(parameterName)) {
				remainderByName.remove(parameterName);
				return this.byName.get(parameterName);
			}

			int index = parameterNames.indexOf(parameterName);
			if (this.byIndex.containsKey(index)) {
				remainderByIndex.remove(index);
				return this.byIndex.get(index);
			}

			return null;
		}

		// PreparedOperation suppliers carry their own bindings; additional
		// bind(…) calls would be silently ignored, so reject them eagerly.
		private void assertNotPreparedOperation() {
			if (this.sqlSupplier instanceof PreparedOperation<?>) {
				throw new InvalidDataAccessApiUsageException(
						"Cannot add bindings to a PreparedOperation");
			}
		}

		private void bindByName(Statement statement, Map<String, Parameter> byName) {
			byName.forEach((name, parameter) -> {
				if (parameter.hasValue()) {
					statement.bind(name, parameter.getValue());
				}
				else {
					statement.bindNull(name, parameter.getType());
				}
			});
		}

		private void bindByIndex(Statement statement, Map<Integer, Parameter> byIndex) {
			byIndex.forEach((i, parameter) -> {
				if (parameter.hasValue()) {
					statement.bind(i, parameter.getValue());
				}
				else {
					statement.bindNull(i, parameter.getType());
				}
			});
		}

		private String getRequiredSql(Supplier<String> sqlSupplier) {
			String sql = sqlSupplier.get();
			Assert.state(StringUtils.hasText(sql),
					"SQL returned by SQL supplier must not be empty!");
			return sql;
		}
	}


	/**
	 * Invocation handler that suppresses close calls on R2DBC Connections. Also prepares
	 * returned Statement (Prepared/CallbackStatement) objects.
	 *
	 * @see Connection#close()
	 */
	private static class CloseSuppressingInvocationHandler implements InvocationHandler {

		private final Connection target;

		CloseSuppressingInvocationHandler(Connection target) {
			this.target = target;
		}

		@Override
		@Nullable
		public Object invoke(Object proxy, Method method, Object[] args)
				throws Throwable {
			// Invocation on ConnectionProxy interface coming in...
			if (method.getName().equals("equals")) {
				// Only consider equal when proxies are identical.
				return proxy == args[0];
			}
			else if (method.getName().equals("hashCode")) {
				// Use hashCode of PersistenceManager proxy.
				return System.identityHashCode(proxy);
			}
			else if (method.getName().equals("unwrap")) {
				return this.target;
			}
			else if (method.getName().equals("close")) {
				// Handle close method: suppress, not valid.
				return Mono.error(
						new UnsupportedOperationException("Close is not supported!"));
			}

			// Invoke method on target Connection.
			try {
				return method.invoke(this.target, args);
			}
			catch (InvocationTargetException ex) {
				// Unwrap to surface the underlying driver exception to callers.
				throw ex.getTargetException();
			}
		}
	}


	/**
	 * Holder for a connection that makes sure the close action is invoked atomically only
	 * once.
	 */
	static class ConnectionCloseHolder extends AtomicBoolean {

		private static final long serialVersionUID = -8994138383301201380L;

		final Connection connection;

		final Function<Connection, Publisher<Void>> closeFunction;

		ConnectionCloseHolder(Connection connection,
				Function<Connection, Publisher<Void>> closeFunction) {
			this.connection = connection;
			this.closeFunction = closeFunction;
		}

		Mono<Void> close() {
			return Mono.defer(() -> {
				// The AtomicBoolean state guards against double-close: only the
				// first caller to flip false->true actually runs the close function.
				if (compareAndSet(false, true)) {
					return Mono.from(this.closeFunction.apply(this.connection));
				}
				return Mono.empty();
			});
		}
	}


	/**
	 * {@link BindTarget} adapter delegating all bind operations to an
	 * underlying R2DBC {@link Statement}.
	 */
	static class StatementWrapper implements BindTarget {

		final Statement statement;

		StatementWrapper(Statement statement) {
			this.statement = statement;
		}

		@Override
		public void bind(String identifier, Object value) {
			this.statement.bind(identifier, value);
		}

		@Override
		public void bind(int index, Object value) {
			this.statement.bind(index, value);
		}

		@Override
		public void bindNull(String identifier, Class<?> type) {
			this.statement.bindNull(identifier, type);
		}

		@Override
		public void bindNull(int index, Class<?> type) {
			this.statement.bindNull(index, type);
		}
	}

}

View File

@ -0,0 +1,106 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core;
import java.util.function.Consumer;
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.Statement;
import org.springframework.lang.Nullable;
import org.springframework.r2dbc.core.binding.BindMarkersFactory;
import org.springframework.r2dbc.core.binding.BindMarkersFactoryResolver;
import org.springframework.util.Assert;
/**
* Default implementation of {@link DatabaseClient.Builder}.
*
* @author Mark Paluch
* @since 5.3
*/
class DefaultDatabaseClientBuilder implements DatabaseClient.Builder {

	@Nullable
	private BindMarkersFactory bindMarkers;

	@Nullable
	private ConnectionFactory connectionFactory;

	// Default execution simply invokes Statement.execute().
	private ExecuteFunction executeFunction = Statement::execute;

	// Named parameter expansion is on by default.
	private boolean namedParameters = true;


	DefaultDatabaseClientBuilder() {
	}


	@Override
	public DatabaseClient.Builder bindMarkers(BindMarkersFactory bindMarkers) {
		Assert.notNull(bindMarkers, "BindMarkersFactory must not be null");
		this.bindMarkers = bindMarkers;
		return this;
	}

	@Override
	public DatabaseClient.Builder connectionFactory(ConnectionFactory factory) {
		Assert.notNull(factory, "ConnectionFactory must not be null");
		this.connectionFactory = factory;
		return this;
	}

	@Override
	public DatabaseClient.Builder executeFunction(ExecuteFunction executeFunction) {
		Assert.notNull(executeFunction, "ExecuteFunction must not be null");
		this.executeFunction = executeFunction;
		return this;
	}

	@Override
	public DatabaseClient.Builder namedParameters(boolean enabled) {
		this.namedParameters = enabled;
		return this;
	}

	@Override
	public DatabaseClient build() {
		Assert.notNull(this.connectionFactory, "ConnectionFactory must not be null");

		// Fall back to a resolved or anonymous BindMarkersFactory when none was
		// configured explicitly: resolve from the ConnectionFactory when named
		// parameters are enabled, otherwise use plain "?" placeholders.
		BindMarkersFactory markersToUse = this.bindMarkers;
		if (markersToUse == null) {
			markersToUse = (this.namedParameters
					? BindMarkersFactoryResolver.resolve(this.connectionFactory)
					: BindMarkersFactory.anonymous("?"));
		}

		return new DefaultDatabaseClient(markersToUse, this.connectionFactory,
				this.executeFunction, this.namedParameters);
	}

	@Override
	public DatabaseClient.Builder apply(
			Consumer<DatabaseClient.Builder> builderConsumer) {
		Assert.notNull(builderConsumer, "BuilderConsumer must not be null");
		builderConsumer.accept(this);
		return this;
	}

}

View File

@ -0,0 +1,100 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core;
import java.util.function.BiFunction;
import java.util.function.Function;
import io.r2dbc.spi.Connection;
import io.r2dbc.spi.Result;
import io.r2dbc.spi.Row;
import io.r2dbc.spi.RowMetadata;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.dao.IncorrectResultSizeDataAccessException;
/**
* Default {@link FetchSpec} implementation.
*
* @author Mark Paluch
* @since 5.3
* @param <T> the row result type
*/
class DefaultFetchSpec<T> implements FetchSpec<T> {

	private final ConnectionAccessor connectionAccessor;

	private final String sql;

	private final Function<Connection, Flux<Result>> resultFunction;

	private final Function<Connection, Mono<Integer>> updatedRowsFunction;

	private final BiFunction<Row, RowMetadata, T> mappingFunction;


	DefaultFetchSpec(ConnectionAccessor connectionAccessor, String sql,
			Function<Connection, Flux<Result>> resultFunction,
			Function<Connection, Mono<Integer>> updatedRowsFunction,
			BiFunction<Row, RowMetadata, T> mappingFunction) {
		this.sql = sql;
		this.connectionAccessor = connectionAccessor;
		this.resultFunction = resultFunction;
		this.updatedRowsFunction = updatedRowsFunction;
		this.mappingFunction = mappingFunction;
	}


	@Override
	public Mono<T> one() {
		// Buffer at most two elements: that is enough to distinguish
		// zero, exactly one, and more-than-one results.
		return all().buffer(2)
				.flatMap(rows -> {
					if (rows.size() == 1) {
						return Mono.just(rows.get(0));
					}
					if (rows.isEmpty()) {
						return Mono.empty();
					}
					return Mono.error(new IncorrectResultSizeDataAccessException(
							String.format("Query [%s] returned non unique result.",
									this.sql),
							1));
				}).next();
	}

	@Override
	public Mono<T> first() {
		return all().next();
	}

	@Override
	public Flux<T> all() {
		// Map each emitted Result through the row mapping function while the
		// connection is still open.
		Function<Connection, Flux<T>> mappedResults = connection ->
				this.resultFunction.apply(connection)
						.flatMap(result -> result.map(this.mappingFunction));
		return this.connectionAccessor.inConnectionMany(
				new ConnectionFunction<>(this.sql, mappedResults));
	}

	@Override
	public Mono<Integer> rowsUpdated() {
		return this.connectionAccessor.inConnection(this.updatedRowsFunction);
	}

}

View File

@ -0,0 +1,57 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core;
import java.util.function.BiFunction;
import io.r2dbc.spi.Result;
import io.r2dbc.spi.Statement;
import org.reactivestreams.Publisher;
/**
* Represents a function that executes a {@link Statement} for a (delayed)
* {@link Result} stream.
*
* <p>Note that discarded {@link Result} objects must be consumed according
* to the R2DBC spec via either {@link Result#getRowsUpdated()} or
* {@link Result#map(BiFunction)}.
*
* <p>Typically, implementations invoke the {@link Statement#execute()} method
* to initiate execution of the statement object.
*
* For example:
* <p><pre class="code">
* DatabaseClient.builder()
* .executeFunction(statement -> statement.execute())
* .build();
* </pre>
*
* @author Mark Paluch
* @since 5.3
* @see Statement#execute()
*/
@FunctionalInterface
public interface ExecuteFunction {

	/**
	 * Execute the given {@link Statement} for a stream of {@link Result}s.
	 * @param statement the request to execute
	 * @return the delayed result stream; typically obtained via {@link Statement#execute()}
	 */
	Publisher<? extends Result> execute(Statement statement);
}

View File

@ -0,0 +1,28 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core;
/**
 * Contract combining {@link RowsFetchSpec} and {@link UpdatedRowsFetchSpec}
 * to fetch either rows or the number of affected rows for a query.
 *
 * @author Mark Paluch
 * @since 5.3
 * @param <T> the row result type
 * @see RowsFetchSpec
 * @see UpdatedRowsFetchSpec
 */
public interface FetchSpec<T> extends RowsFetchSpec<T>, UpdatedRowsFetchSpec {}

View File

@ -0,0 +1,105 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core;
import java.util.LinkedHashMap;
import java.util.Map;
import org.springframework.util.Assert;
/**
* {@link BindParameterSource} implementation that holds a given {@link Map} of parameters
* encapsulated as {@link Parameter}.
*
* <p>This class is intended for passing in a simple Map of parameter values to the methods
* of the {@code NamedParameterExpander} class.
*
* @author Mark Paluch
* @since 5.3
*/
class MapBindParameterSource implements BindParameterSource {

	private final Map<String, Parameter> values;


	/**
	 * Create a new, empty {@link MapBindParameterSource} backed by a
	 * mutable {@link LinkedHashMap}.
	 */
	MapBindParameterSource() {
		this(new LinkedHashMap<>());
	}

	/**
	 * Create a new {@link MapBindParameterSource} holding the given
	 * parameter-name to {@link Parameter} mapping.
	 * @param values the parameter mapping; must not be {@code null}
	 */
	MapBindParameterSource(Map<String, Parameter> values) {
		Assert.notNull(values, "Values must not be null");
		this.values = values;
	}


	/**
	 * Register the given non-{@code null} value under the given parameter name.
	 * @param paramName must not be {@code null}
	 * @param value must not be {@code null}
	 * @return {@code this} {@link MapBindParameterSource} for chaining
	 */
	MapBindParameterSource addValue(String paramName, Object value) {
		Assert.notNull(paramName, "Parameter name must not be null");
		Assert.notNull(value, "Value must not be null");
		this.values.put(paramName, Parameter.fromOrEmpty(value, value.getClass()));
		return this;
	}

	@Override
	public boolean hasValue(String paramName) {
		Assert.notNull(paramName, "Parameter name must not be null");
		return this.values.containsKey(paramName);
	}

	@Override
	public Class<?> getType(String paramName) {
		Assert.notNull(paramName, "Parameter name must not be null");
		Parameter parameter = this.values.get(paramName);
		// Fall back to Object when the parameter is unknown.
		return (parameter != null ? parameter.getType() : Object.class);
	}

	@Override
	public Object getValue(String paramName) throws IllegalArgumentException {
		Assert.notNull(paramName, "Parameter name must not be null");
		if (!this.values.containsKey(paramName)) {
			throw new IllegalArgumentException(
					"No value registered for key '" + paramName + "'");
		}
		return this.values.get(paramName).getValue();
	}

	@Override
	public Iterable<String> getParameterNames() {
		return this.values.keySet();
	}

}

View File

@ -0,0 +1,162 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.r2dbc.core.binding.BindMarkersFactory;
/**
* SQL translation support allowing the use of named parameters
* rather than native placeholders.
*
* <p>This class expands SQL from named parameters to native
* style placeholders at execution time. It also allows for expanding
* a {@link List} of values to the appropriate number of placeholders.
*
* <p>References to the same parameter name are substituted with the
* same bind marker placeholder if a {@link BindMarkersFactory} uses
* {@link BindMarkersFactory#identifiablePlaceholders() identifiable} placeholders.
* <p><b>NOTE: An instance of this class is thread-safe once configured.</b>
*
 * @author Mark Paluch
 * @since 5.3
*/
class NamedParameterExpander {

	/**
	 * Default maximum number of entries for the SQL cache: 256.
	 */
	public static final int DEFAULT_CACHE_LIMIT = 256;

	private volatile int cacheLimit = DEFAULT_CACHE_LIMIT;

	private final Log logger = LogFactory.getLog(getClass());

	/**
	 * Cache of original SQL String to ParsedSql representation.
	 * <p>Created with {@code accessOrder=true}, so together with
	 * {@code removeEldestEntry} it behaves as a bounded LRU cache.
	 */
	@SuppressWarnings("serial")
	private final Map<String, ParsedSql> parsedSqlCache = new LinkedHashMap<String, ParsedSql>(
			DEFAULT_CACHE_LIMIT, 0.75f, true) {
		@Override
		protected boolean removeEldestEntry(Map.Entry<String, ParsedSql> eldest) {
			// Evict the least-recently-accessed entry once the configured limit is exceeded.
			return size() > getCacheLimit();
		}
	};

	/**
	 * Create a new enabled instance of {@link NamedParameterExpander}.
	 */
	public NamedParameterExpander() {}

	/**
	 * Specify the maximum number of entries for the SQL cache. Default is 256.
	 * <p>A value of {@code 0} or less disables caching entirely.
	 */
	public void setCacheLimit(int cacheLimit) {
		this.cacheLimit = cacheLimit;
	}

	/**
	 * Return the maximum number of entries for the SQL cache.
	 */
	public int getCacheLimit() {
		return this.cacheLimit;
	}

	/**
	 * Obtain a parsed representation of the given SQL statement.
	 * <p>
	 * The default implementation uses an LRU cache with an upper limit of 256 entries.
	 *
	 * @param sql the original SQL statement
	 * @return a representation of the parsed SQL statement
	 */
	private ParsedSql getParsedSql(String sql) {
		if (getCacheLimit() <= 0) {
			// Caching disabled: parse on every invocation.
			return NamedParameterUtils.parseSqlStatement(sql);
		}
		synchronized (this.parsedSqlCache) {
			ParsedSql parsedSql = this.parsedSqlCache.get(sql);
			if (parsedSql == null) {
				parsedSql = NamedParameterUtils.parseSqlStatement(sql);
				this.parsedSqlCache.put(sql, parsedSql);
			}
			return parsedSql;
		}
	}

	/**
	 * Parse the SQL statement and locate any placeholders or named parameters.
	 * Named parameters are substituted for a native placeholder, and any
	 * select list is expanded to the required number of placeholders. Select
	 * lists may contain an array of objects, and in that case the placeholders
	 * will be grouped and enclosed with parentheses. This allows for the use of
	 * "expression lists" in the SQL statement like:
	 *
	 * <pre class="code">
	 * select id, name, state from table where (name, age) in (('John', 35), ('Ann', 50))
	 * </pre>
	 *
	 * <p>The parameter values passed in are used to determine the number of
	 * placeholders to be used for a select list. Select lists should be limited
	 * to 100 or fewer elements. A larger number of elements is not guaranteed to be
	 * supported by the database and is strictly vendor-dependent.
	 * @param sql the original SQL statement
	 * @param bindMarkersFactory the bind marker factory
	 * @param paramSource the source for named parameters
	 * @return the expanded sql that accepts bind parameters and allows for execution
	 * without further translation wrapped as {@link PreparedOperation}.
	 */
	public PreparedOperation<String> expand(String sql, BindMarkersFactory bindMarkersFactory,
			BindParameterSource paramSource) {
		ParsedSql parsedSql = getParsedSql(sql);
		PreparedOperation<String> expanded = NamedParameterUtils.substituteNamedParameters(parsedSql, bindMarkersFactory,
				paramSource);
		if (logger.isDebugEnabled()) {
			logger.debug(String.format("Expanding SQL statement [%s] to [%s]", sql, expanded.toQuery()));
		}
		return expanded;
	}

	/**
	 * Parse the SQL statement and locate any placeholders or named parameters. Named parameters are returned as result of
	 * this method invocation.
	 *
	 * @return the parameter names.
	 */
	public List<String> getParameterNames(String sql) {
		return getParsedSql(sql).getParameterNames();
	}

}

View File

@ -0,0 +1,630 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeMap;
import org.springframework.dao.InvalidDataAccessApiUsageException;
import org.springframework.lang.Nullable;
import org.springframework.r2dbc.core.binding.BindMarker;
import org.springframework.r2dbc.core.binding.BindMarkers;
import org.springframework.r2dbc.core.binding.BindMarkersFactory;
import org.springframework.r2dbc.core.binding.BindTarget;
import org.springframework.util.Assert;
/**
* Helper methods for named parameter parsing.
*
* <p>Only intended for internal use within Spring's R2DBC
* framework.
*
* <p>References to the same parameter name are substituted with
* the same bind marker placeholder if a {@link BindMarkersFactory} uses
* {@link BindMarkersFactory#identifiablePlaceholders() identifiable}
* placeholders.
*
* @author Thomas Risberg
* @author Juergen Hoeller
* @author Mark Paluch
* @since 5.3
*/
abstract class NamedParameterUtils {

	/**
	 * Set of characters that qualify as comment or quotes starting characters.
	 */
	private static final String[] START_SKIP = new String[] {"'", "\"", "--", "/*"};

	/**
	 * Set of characters that are the corresponding comment or quotes ending characters.
	 */
	private static final String[] STOP_SKIP = new String[] {"'", "\"", "\n", "*/"};

	/**
	 * Set of characters that qualify as parameter separators,
	 * indicating that a parameter name in an SQL String has ended.
	 */
	private static final String PARAMETER_SEPARATORS = "\"':&,;()|=+-*%/\\<>^";

	/**
	 * An index with separator flags per character code.
	 * Technically only needed between 34 and 124 at this point.
	 */
	private static final boolean[] separatorIndex = new boolean[128];

	static {
		for (char c : PARAMETER_SEPARATORS.toCharArray()) {
			separatorIndex[c] = true;
		}
	}

	// -------------------------------------------------------------------------
	// Core methods used by NamedParameterSupport.
	// -------------------------------------------------------------------------

	/**
	 * Parse the SQL statement and locate any placeholders or named parameters.
	 * Named parameters are substituted for a R2DBC placeholder.
	 *
	 * @param sql the SQL statement
	 * @return the parsed statement, represented as {@link ParsedSql} instance.
	 */
	public static ParsedSql parseSqlStatement(String sql) {
		Assert.notNull(sql, "SQL must not be null");
		Set<String> namedParameters = new HashSet<>();
		String sqlToUse = sql;
		List<ParameterHolder> parameterList = new ArrayList<>();
		char[] statement = sql.toCharArray();
		int namedParameterCount = 0;
		int unnamedParameterCount = 0;
		int totalParameterCount = 0;
		// Number of "\:" escape characters removed from sqlToUse so far;
		// recorded parameter indexes are shifted left by this amount.
		int escapes = 0;
		int i = 0;
		while (i < statement.length) {
			int skipToPosition = i;
			// Skip over comments and quoted literals before inspecting the next character.
			while (i < statement.length) {
				skipToPosition = skipCommentsAndQuotes(statement, i);
				if (i == skipToPosition) {
					break;
				}
				else {
					i = skipToPosition;
				}
			}
			if (i >= statement.length) {
				break;
			}
			char c = statement[i];
			if (c == ':' || c == '&') {
				int j = i + 1;
				if (c == ':' && j < statement.length && statement[j] == ':') {
					// Postgres-style "::" casting operator should be skipped
					i = i + 2;
					continue;
				}
				String parameter = null;
				if (c == ':' && j < statement.length && statement[j] == '{') {
					// :{x} style parameter
					while (statement[j] != '}') {
						j++;
						if (j >= statement.length) {
							throw new InvalidDataAccessApiUsageException("Non-terminated named parameter declaration " +
									"at position " + i + " in statement: " + sql);
						}
						if (statement[j] == ':' || statement[j] == '{') {
							throw new InvalidDataAccessApiUsageException("Parameter name contains invalid character '" +
									statement[j] + "' at position " + i + " in statement: " + sql);
						}
					}
					// Only register non-empty names (":{}" is silently ignored).
					if (j - i > 2) {
						parameter = sql.substring(i + 2, j);
						namedParameterCount = addNewNamedParameter(namedParameters,
								namedParameterCount, parameter);
						totalParameterCount = addNamedParameter(parameterList,
								totalParameterCount, escapes, i, j + 1, parameter);
					}
					j++;
				}
				else {
					// Plain :name (or &name) parameter: consume until the next separator.
					while (j < statement.length && !isParameterSeparator(statement[j])) {
						j++;
					}
					if (j - i > 1) {
						parameter = sql.substring(i + 1, j);
						namedParameterCount = addNewNamedParameter(namedParameters,
								namedParameterCount, parameter);
						totalParameterCount = addNamedParameter(parameterList,
								totalParameterCount, escapes, i, j, parameter);
					}
				}
				i = j - 1;
			}
			else {
				if (c == '\\') {
					int j = i + 1;
					if (j < statement.length && statement[j] == ':') {
						// escaped ":" should be skipped
						sqlToUse = sqlToUse.substring(0, i - escapes)
								+ sqlToUse.substring(i - escapes + 1);
						escapes++;
						i = i + 2;
						continue;
					}
				}
			}
			i++;
		}
		ParsedSql parsedSql = new ParsedSql(sqlToUse);
		for (ParameterHolder ph : parameterList) {
			parsedSql.addNamedParameter(ph.getParameterName(), ph.getStartIndex(), ph.getEndIndex());
		}
		parsedSql.setNamedParameterCount(namedParameterCount);
		parsedSql.setUnnamedParameterCount(unnamedParameterCount);
		parsedSql.setTotalParameterCount(totalParameterCount);
		return parsedSql;
	}

	// Records one occurrence of a parameter (including repeats) and returns the
	// updated total count.
	private static int addNamedParameter(
			List<ParameterHolder> parameterList, int totalParameterCount, int escapes, int i, int j, String parameter) {
		parameterList.add(new ParameterHolder(parameter, i - escapes, j - escapes));
		totalParameterCount++;
		return totalParameterCount;
	}

	// Records a parameter name if not seen before and returns the updated count
	// of distinct names.
	private static int addNewNamedParameter(Set<String> namedParameters, int namedParameterCount, String parameter) {
		if (!namedParameters.contains(parameter)) {
			namedParameters.add(parameter);
			namedParameterCount++;
		}
		return namedParameterCount;
	}

	/**
	 * Skip over comments and quoted names present in an SQL statement.
	 * @param statement character array containing SQL statement
	 * @param position current position of statement
	 * @return next position to process after any comments or quotes are skipped
	 */
	private static int skipCommentsAndQuotes(char[] statement, int position) {
		for (int i = 0; i < START_SKIP.length; i++) {
			if (statement[position] == START_SKIP[i].charAt(0)) {
				boolean match = true;
				for (int j = 1; j < START_SKIP[i].length(); j++) {
					if (statement[position + j] != START_SKIP[i].charAt(j)) {
						match = false;
						break;
					}
				}
				if (match) {
					int offset = START_SKIP[i].length();
					// Scan forward for the corresponding stop sequence (STOP_SKIP[i]).
					for (int m = position + offset; m < statement.length; m++) {
						if (statement[m] == STOP_SKIP[i].charAt(0)) {
							boolean endMatch = true;
							int endPos = m;
							for (int n = 1; n < STOP_SKIP[i].length(); n++) {
								if (m + n >= statement.length) {
									// last comment not closed properly
									return statement.length;
								}
								if (statement[m + n] != STOP_SKIP[i].charAt(n)) {
									endMatch = false;
									break;
								}
								endPos = m + n;
							}
							if (endMatch) {
								// found character sequence ending comment or quote
								return endPos + 1;
							}
						}
					}
					// character sequence ending comment or quote not found
					return statement.length;
				}
			}
		}
		return position;
	}

	/**
	 * Parse the SQL statement and locate any placeholders or named parameters. Named
	 * parameters are substituted for a R2DBC placeholder, and any select list is expanded
	 * to the required number of placeholders. Select lists may contain an array of
	 * objects, and in that case the placeholders will be grouped and enclosed with
	 * parentheses. This allows for the use of "expression lists" in the SQL statement
	 * like: <br /><br />
	 * {@code select id, name, state from table where (name, age) in (('John', 35), ('Ann', 50))}
	 * <p>The parameter values passed in are used to determine the number of placeholders to
	 * be used for a select list. Select lists should be limited to 100 or fewer elements.
	 * A larger number of elements is not guaranteed to be supported by the database and
	 * is strictly vendor-dependent.
	 * @param parsedSql the parsed representation of the SQL statement
	 * @param bindMarkersFactory the bind marker factory.
	 * @param paramSource the source for named parameters
	 * @return the expanded query that accepts bind parameters and allows for execution
	 * without further translation
	 * @see #parseSqlStatement
	 */
	public static PreparedOperation<String> substituteNamedParameters(ParsedSql parsedSql,
			BindMarkersFactory bindMarkersFactory, BindParameterSource paramSource) {
		NamedParameters markerHolder = new NamedParameters(bindMarkersFactory);
		String originalSql = parsedSql.getOriginalSql();
		List<String> paramNames = parsedSql.getParameterNames();
		if (paramNames.isEmpty()) {
			// Nothing to substitute: pass the SQL through unchanged.
			return new ExpandedQuery(originalSql, markerHolder, paramSource);
		}
		StringBuilder actualSql = new StringBuilder(originalSql.length());
		int lastIndex = 0;
		for (int i = 0; i < paramNames.size(); i++) {
			String paramName = paramNames.get(i);
			int[] indexes = parsedSql.getParameterIndexes(i);
			int startIndex = indexes[0];
			int endIndex = indexes[1];
			// Copy the SQL text between the previous parameter and this one.
			actualSql.append(originalSql, lastIndex, startIndex);
			NamedParameters.NamedParameter marker = markerHolder.getOrCreate(paramName);
			if (paramSource.hasValue(paramName)) {
				Object value = paramSource.getValue(paramName);
				if (value instanceof Collection) {
					// Expand a collection value into one placeholder per element;
					// Object[] elements become parenthesized expression lists.
					Iterator<?> entryIter = ((Collection<?>) value).iterator();
					int k = 0;
					int counter = 0;
					while (entryIter.hasNext()) {
						if (k > 0) {
							actualSql.append(", ");
						}
						k++;
						Object entryItem = entryIter.next();
						if (entryItem instanceof Object[]) {
							Object[] expressionList = (Object[]) entryItem;
							actualSql.append('(');
							for (int m = 0; m < expressionList.length; m++) {
								if (m > 0) {
									actualSql.append(", ");
								}
								actualSql.append(marker.getPlaceholder(counter));
								counter++;
							}
							actualSql.append(')');
						}
						else {
							actualSql.append(marker.getPlaceholder(counter));
							counter++;
						}
					}
				}
				else {
					actualSql.append(marker.getPlaceholder());
				}
			}
			else {
				actualSql.append(marker.getPlaceholder());
			}
			lastIndex = endIndex;
		}
		// Append the trailing SQL text after the last parameter.
		actualSql.append(originalSql, lastIndex, originalSql.length());
		return new ExpandedQuery(actualSql.toString(), markerHolder, paramSource);
	}

	/**
	 * Determine whether a parameter name ends at the current position,
	 * that is, whether the given character qualifies as a separator.
	 */
	private static boolean isParameterSeparator(char c) {
		return (c < 128 && separatorIndex[c]) || Character.isWhitespace(c);
	}

	// -------------------------------------------------------------------------
	// Convenience methods operating on a plain SQL String
	// -------------------------------------------------------------------------

	/**
	 * Parse the SQL statement and locate any placeholders or named parameters.
	 * Named parameters are substituted for a native placeholder and any
	 * select list is expanded to the required number of placeholders.
	 * @param sql the SQL statement
	 * @param bindMarkersFactory the bind marker factory
	 * @param paramSource the source for named parameters
	 * @return the expanded query that accepts bind parameters and allows for execution
	 * without further translation
	 */
	public static PreparedOperation<String> substituteNamedParameters(String sql,
			BindMarkersFactory bindMarkersFactory, BindParameterSource paramSource) {
		ParsedSql parsedSql = parseSqlStatement(sql);
		return substituteNamedParameters(parsedSql, bindMarkersFactory, paramSource);
	}

	/**
	 * Value object capturing one occurrence of a named parameter and its
	 * start/end offsets within the (escape-adjusted) SQL string.
	 */
	private static class ParameterHolder {

		private final String parameterName;

		private final int startIndex;

		private final int endIndex;

		ParameterHolder(String parameterName, int startIndex, int endIndex) {
			this.parameterName = parameterName;
			this.startIndex = startIndex;
			this.endIndex = endIndex;
		}

		String getParameterName() {
			return this.parameterName;
		}

		int getStartIndex() {
			return this.startIndex;
		}

		int getEndIndex() {
			return this.endIndex;
		}

		@Override
		public boolean equals(Object o) {
			if (this == o) {
				return true;
			}
			if (!(o instanceof ParameterHolder)) {
				return false;
			}
			ParameterHolder that = (ParameterHolder) o;
			return this.startIndex == that.startIndex && this.endIndex == that.endIndex
					&& Objects.equals(this.parameterName, that.parameterName);
		}

		@Override
		public int hashCode() {
			return Objects.hash(this.parameterName, this.startIndex, this.endIndex);
		}
	}

	/**
	 * Holder for bind markers progress.
	 */
	static class NamedParameters {

		private final BindMarkers bindMarkers;

		// Whether the factory produces identifiable placeholders that may be
		// reused for repeated occurrences of the same parameter name.
		private final boolean identifiable;

		private final Map<String, List<NamedParameter>> references = new TreeMap<>();

		NamedParameters(BindMarkersFactory factory) {
			this.bindMarkers = factory.create();
			this.identifiable = factory.identifiablePlaceholders();
		}

		/**
		 * Get the {@link NamedParameter} identified by {@code namedParameter}.
		 * Parameter objects get created if they do not yet exist.
		 * @param namedParameter the parameter name
		 * @return the named parameter
		 */
		NamedParameter getOrCreate(String namedParameter) {
			List<NamedParameter> reference = this.references.computeIfAbsent(
					namedParameter, ignore -> new ArrayList<>());
			if (reference.isEmpty()) {
				NamedParameter param = new NamedParameter(namedParameter);
				reference.add(param);
				return param;
			}
			if (this.identifiable) {
				// Identifiable placeholders: reuse the same marker for repeats.
				return reference.get(0);
			}
			// Anonymous placeholders: every occurrence needs its own marker.
			NamedParameter param = new NamedParameter(namedParameter);
			reference.add(param);
			return param;
		}

		@Nullable
		List<NamedParameter> getMarker(String name) {
			return this.references.get(name);
		}

		class NamedParameter {

			private final String namedParameter;

			private final List<BindMarker> placeholders = new ArrayList<>();

			NamedParameter(String namedParameter) {
				this.namedParameter = namedParameter;
			}

			/**
			 * Create a placeholder to translate a single value into a bindable parameter.
			 * <p>Can be called multiple times to create placeholders for array/collections.
			 * @return the placeholder to be used in the SQL statement
			 */
			String addPlaceholder() {
				BindMarker bindMarker = NamedParameters.this.bindMarkers.next(
						this.namedParameter);
				this.placeholders.add(bindMarker);
				return bindMarker.getPlaceholder();
			}

			String getPlaceholder() {
				return getPlaceholder(0);
			}

			String getPlaceholder(int counter) {
				// Lazily create placeholders up to the requested index.
				while (counter + 1 > this.placeholders.size()) {
					addPlaceholder();
				}
				return this.placeholders.get(counter).getPlaceholder();
			}
		}
	}

	/**
	 * Expanded query that allows binding of parameters using parameter names that were
	 * used to expand the query. Binding unrolls {@link Collection}s and nested arrays.
	 */
	private static class ExpandedQuery implements PreparedOperation<String> {

		private final String expandedSql;

		private final NamedParameters parameters;

		private final BindParameterSource parameterSource;

		ExpandedQuery(String expandedSql, NamedParameters parameters,
				BindParameterSource parameterSource) {
			this.expandedSql = expandedSql;
			this.parameters = parameters;
			this.parameterSource = parameterSource;
		}

		// Bind a single named value, unrolling Collection values (and nested
		// Object[] elements) across the markers created during expansion.
		@SuppressWarnings("unchecked")
		public void bind(BindTarget target, String identifier, Object value) {
			List<BindMarker> bindMarkers = getBindMarkers(identifier);
			if (bindMarkers == null) {
				// Unknown identifier: delegate directly to the target.
				target.bind(identifier, value);
				return;
			}
			if (value instanceof Collection) {
				Collection<Object> collection = (Collection<Object>) value;
				Iterator<Object> iterator = collection.iterator();
				Iterator<BindMarker> markers = bindMarkers.iterator();
				while (iterator.hasNext()) {
					Object valueToBind = iterator.next();
					if (valueToBind instanceof Object[]) {
						Object[] objects = (Object[]) valueToBind;
						for (Object object : objects) {
							bind(target, markers, object);
						}
					}
					else {
						bind(target, markers, valueToBind);
					}
				}
			}
			else {
				for (BindMarker bindMarker : bindMarkers) {
					bindMarker.bind(target, value);
				}
			}
		}

		// Bind one element to the next marker; fails if the expansion produced
		// fewer markers than the supplied values require.
		private void bind(BindTarget target, Iterator<BindMarker> markers,
				Object valueToBind) {
			Assert.isTrue(markers.hasNext(), () -> String.format(
					"No bind marker for value [%s] in SQL [%s]. Check that the query was expanded using the same arguments.",
					valueToBind, toQuery()));
			markers.next().bind(target, valueToBind);
		}

		// Bind a typed null for the given identifier to all of its markers.
		public void bindNull(BindTarget target, String identifier, Class<?> valueType) {
			List<BindMarker> bindMarkers = getBindMarkers(identifier);
			if (bindMarkers == null) {
				target.bindNull(identifier, valueType);
				return;
			}
			for (BindMarker bindMarker : bindMarkers) {
				bindMarker.bindNull(target, valueType);
			}
		}

		// Collect all markers recorded for the identifier, or null if the
		// identifier was not part of the expanded query.
		@Nullable
		List<BindMarker> getBindMarkers(String identifier) {
			List<NamedParameters.NamedParameter> parameters = this.parameters.getMarker(
					identifier);
			if (parameters == null) {
				return null;
			}
			List<BindMarker> markers = new ArrayList<>();
			for (NamedParameters.NamedParameter parameter : parameters) {
				markers.addAll(parameter.placeholders);
			}
			return markers;
		}

		@Override
		public String getSource() {
			return this.expandedSql;
		}

		@Override
		public void bindTo(BindTarget target) {
			for (String namedParameter : this.parameterSource.getParameterNames()) {
				Object value = this.parameterSource.getValue(namedParameter);
				if (value == null) {
					bindNull(target, namedParameter,
							this.parameterSource.getType(namedParameter));
				}
				else {
					bind(target, namedParameter, value);
				}
			}
		}

		@Override
		public String toQuery() {
			return this.expandedSql;
		}
	}

}

View File

@ -0,0 +1,137 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core;
import java.util.Objects;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ObjectUtils;
/**
* A database value that can be set in a statement.
*
* @author Mark Paluch
* @since 5.3
*/
public final class Parameter {

	@Nullable
	private final Object value;

	private final Class<?> type;

	private Parameter(@Nullable Object value, Class<?> type) {
		Assert.notNull(type, "Type must not be null");
		this.value = value;
		this.type = type;
	}

	/**
	 * Create a new {@link Parameter} from {@code value}.
	 * @param value must not be {@code null}
	 * @return the {@link Parameter} value for {@code value}
	 */
	public static Parameter from(Object value) {
		Assert.notNull(value, "Value must not be null");
		return new Parameter(value, ClassUtils.getUserClass(value));
	}

	/**
	 * Create a new {@link Parameter} from {@code value} and {@code type}.
	 * <p>For a non-null {@code value}, the actual (user) class of the value is
	 * used; the given {@code type} only applies when {@code value} is {@code null}.
	 * @param value can be {@code null}
	 * @param type must not be {@code null}
	 * @return the {@link Parameter} value for {@code value}
	 */
	public static Parameter fromOrEmpty(@Nullable Object value, Class<?> type) {
		return value == null ? empty(type) : new Parameter(value, ClassUtils.getUserClass(value));
	}

	/**
	 * Create a new empty {@link Parameter} for {@code type}.
	 * @return the empty {@link Parameter} value for {@code type}
	 */
	public static Parameter empty(Class<?> type) {
		Assert.notNull(type, "Type must not be null");
		return new Parameter(null, type);
	}

	/**
	 * Returns the column value. Can be {@code null}.
	 * @return the column value. Can be {@code null}
	 * @see #hasValue()
	 */
	@Nullable
	public Object getValue() {
		return this.value;
	}

	/**
	 * Returns the column value type. Must be also present if the {@code value} is {@code null}.
	 * @return the column value type
	 */
	public Class<?> getType() {
		return this.type;
	}

	/**
	 * Returns whether this {@link Parameter} has a value.
	 * @return whether this {@link Parameter} has a value. {@code false} if {@link #getValue()} is {@code null}
	 */
	public boolean hasValue() {
		return this.value != null;
	}

	/**
	 * Returns whether this {@link Parameter} is empty.
	 * @return whether this {@link Parameter} is empty. {@code true} if {@link #getValue()} is {@code null}
	 */
	public boolean isEmpty() {
		return this.value == null;
	}

	@Override
	public boolean equals(@Nullable Object o) {
		if (this == o) {
			return true;
		}
		if (!(o instanceof Parameter)) {
			return false;
		}
		Parameter other = (Parameter) o;
		return ObjectUtils.nullSafeEquals(this.value, other.value) && ObjectUtils.nullSafeEquals(this.type, other.type);
	}

	@Override
	public int hashCode() {
		return Objects.hash(this.value, this.type);
	}

	@Override
	public String toString() {
		// Local, single-threaded assembly: StringBuilder avoids the needless
		// synchronization overhead of the legacy StringBuffer.
		StringBuilder sb = new StringBuilder();
		sb.append("Parameter");
		sb.append("[value=").append(this.value);
		sb.append(", type=").append(this.type);
		sb.append(']');
		return sb.toString();
	}

}

View File

@ -0,0 +1,145 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core;
import java.util.ArrayList;
import java.util.List;
/**
* Holds information about a parsed SQL statement.
*
* @author Thomas Risberg
* @author Juergen Hoeller
* @since 5.3
*/
class ParsedSql {

	// The (escape-adjusted) SQL this holder was created from; never reassigned.
	private final String originalSql;

	// Parameter names in order of occurrence; repeats are included.
	private final List<String> parameterNames = new ArrayList<>();

	// Start/end offsets (int[2]) parallel to parameterNames.
	private final List<int[]> parameterIndexes = new ArrayList<>();

	private int namedParameterCount;

	private int unnamedParameterCount;

	private int totalParameterCount;

	/**
	 * Create a new instance of the {@link ParsedSql} class.
	 * @param originalSql the SQL statement that is being (or is to be) parsed
	 */
	ParsedSql(String originalSql) {
		this.originalSql = originalSql;
	}

	/**
	 * Return the SQL statement that is being parsed.
	 */
	String getOriginalSql() {
		return this.originalSql;
	}

	/**
	 * Add a named parameter parsed from this SQL statement.
	 * @param parameterName the name of the parameter
	 * @param startIndex the start index in the original SQL String
	 * @param endIndex the end index in the original SQL String
	 */
	void addNamedParameter(String parameterName, int startIndex, int endIndex) {
		this.parameterNames.add(parameterName);
		this.parameterIndexes.add(new int[] {startIndex, endIndex});
	}

	/**
	 * Return all of the parameters (bind variables) in the parsed SQL statement.
	 * Repeated occurrences of the same parameter name are included here.
	 */
	List<String> getParameterNames() {
		return this.parameterNames;
	}

	/**
	 * Return the parameter indexes for the specified parameter.
	 * @param parameterPosition the position of the parameter
	 * (as index in the parameter names List)
	 * @return the start index and end index, combined into
	 * a int array of length 2
	 */
	int[] getParameterIndexes(int parameterPosition) {
		return this.parameterIndexes.get(parameterPosition);
	}

	/**
	 * Set the count of named parameters in the SQL statement.
	 * Each parameter name counts once; repeated occurrences do not count here.
	 */
	void setNamedParameterCount(int namedParameterCount) {
		this.namedParameterCount = namedParameterCount;
	}

	/**
	 * Return the count of named parameters in the SQL statement.
	 * Each parameter name counts once; repeated occurrences do not count here.
	 */
	int getNamedParameterCount() {
		return this.namedParameterCount;
	}

	/**
	 * Set the count of all of the unnamed parameters in the SQL statement.
	 */
	void setUnnamedParameterCount(int unnamedParameterCount) {
		this.unnamedParameterCount = unnamedParameterCount;
	}

	/**
	 * Return the count of all of the unnamed parameters in the SQL statement.
	 */
	int getUnnamedParameterCount() {
		return this.unnamedParameterCount;
	}

	/**
	 * Set the total count of all of the parameters in the SQL statement.
	 * Repeated occurrences of the same parameter name do count here.
	 */
	void setTotalParameterCount(int totalParameterCount) {
		this.totalParameterCount = totalParameterCount;
	}

	/**
	 * Return the total count of all of the parameters in the SQL statement.
	 * Repeated occurrences of the same parameter name do count here.
	 */
	int getTotalParameterCount() {
		return this.totalParameterCount;
	}

	/**
	 * Exposes the original SQL String.
	 */
	@Override
	public String toString() {
		return this.originalSql;
	}

}

View File

@ -0,0 +1,48 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core;
import java.util.function.Supplier;
import org.springframework.r2dbc.core.binding.BindTarget;
/**
* Extension to {@link QueryOperation} for a prepared SQL query
* {@link Supplier} with bound parameters. Contains parameter
 * bindings that can be {@link #bindTo bound} to a {@link BindTarget}.
* <p>Can be executed with {@link org.springframework.r2dbc.core.DatabaseClient}.
*
* @author Mark Paluch
* @since 5.3
* @param <T> underlying operation source.
* @see org.springframework.r2dbc.core.DatabaseClient#sql(Supplier)
*/
public interface PreparedOperation<T> extends QueryOperation {

	/**
	 * Return the underlying query source.
	 * @return the query source, such as a statement/criteria object
	 */
	T getSource();

	/**
	 * Apply bindings to the given {@link BindTarget}.
	 * @param target the target to apply bindings to
	 */
	void bindTo(BindTarget target);

}

View File

@ -0,0 +1,47 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core;
import java.util.function.Supplier;
/**
* Interface declaring a query operation that can be represented
* with a query string. This interface is typically implemented
* by classes representing a SQL operation such as {@code SELECT},
* {@code INSERT}, and such.
*
* @author Mark Paluch
* @since 5.3
* @see PreparedOperation
*/
@FunctionalInterface
public interface QueryOperation extends Supplier<String> {

	/**
	 * Return the string representation of this operation to
	 * be used with {@link io.r2dbc.spi.Statement} creation.
	 * @return the operation as SQL string
	 * @see io.r2dbc.spi.Connection#createStatement(String)
	 */
	String toQuery();

	/**
	 * Delegates to {@link #toQuery()} to satisfy the {@link Supplier} contract.
	 */
	@Override
	default String get() {
		return toQuery();
	}

}

View File

@ -0,0 +1,51 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
/**
 * Contract for fetching tabular results.
 *
 * @author Mark Paluch
 * @since 5.3
 * @param <T> the row result type
 */
public interface RowsFetchSpec<T> {

	/**
	 * Get exactly zero or one result.
	 * @return a mono emitting one element, or {@link Mono#empty()} if no match found.
	 * Completes with {@code IncorrectResultSizeDataAccessException} if more than one match found
	 */
	Mono<T> one();

	/**
	 * Get the first result or no result.
	 * <p>In contrast to {@link #one()}, additional rows do not terminate the
	 * sequence with an error per this contract.
	 * @return a mono emitting the first element, or {@link Mono#empty()} if no match found
	 */
	Mono<T> first();

	/**
	 * Get all matching elements.
	 * @return a flux emitting all results
	 */
	Flux<T> all();
}

View File

@ -0,0 +1,41 @@
/*
* Copyright 2002-2012 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core;
import org.springframework.lang.Nullable;
/**
 * Interface to be implemented by objects that can provide SQL strings.
 *
 * <p>Typically implemented by objects that want to expose the SQL they
 * use to create their statements, to allow for better contextual
 * information in case of exceptions.
 *
 * @author Mark Paluch
 * @since 5.3
 */
public interface SqlProvider {

	/**
	 * Return the SQL string for this object, i.e.
	 * typically the SQL used for creating statements.
	 * @return the SQL string, or {@code null} if not available
	 */
	@Nullable
	String getSql();
}

View File

@ -0,0 +1,61 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core;
import io.r2dbc.spi.Result;
import io.r2dbc.spi.Statement;
import org.reactivestreams.Publisher;
import org.springframework.util.Assert;
/**
 * Represents a function that filters an {@link ExecuteFunction execute function}.
 *
 * <p>The filter is executed when a {@link org.reactivestreams.Subscriber} subscribes
 * to the {@link Publisher} returned by the {@link DatabaseClient}.
 *
 * <p>Statement filter functions are typically used to specify additional details
 * on the {@link Statement} such as {@code fetchSize} or key generation.
 *
 * @author Mark Paluch
 * @since 5.3
 * @see ExecuteFunction
 */
@FunctionalInterface
public interface StatementFilterFunction {

	/**
	 * Apply this filter to the given {@link Statement} and {@link ExecuteFunction}.
	 * <p>The given {@link ExecuteFunction} represents the next entity in the chain,
	 * to be invoked via {@link ExecuteFunction#execute(Statement)} in order to
	 * proceed with the execution, or not invoked to short-circuit the chain.
	 * @param statement the current {@link Statement}
	 * @param next the next execute function in the chain
	 * @return the filtered {@link Result}s
	 */
	Publisher<? extends Result> filter(Statement statement, ExecuteFunction next);

	/**
	 * Return a composed filter function that first applies this filter and then
	 * applies the given {@code afterFilter}.
	 * @param afterFilter the filter to apply after this filter
	 * @return the composed filter
	 */
	default StatementFilterFunction andThen(StatementFilterFunction afterFilter) {
		Assert.notNull(afterFilter, "StatementFilterFunction must not be null");
		return (statement, next) -> {
			// Wrap the downstream execute function so that 'afterFilter' runs
			// after this filter has decided to proceed with the given statement.
			ExecuteFunction chainedNext = chainedStatement -> afterFilter.filter(chainedStatement, next);
			return filter(statement, chainedNext);
		};
	}
}

View File

@ -0,0 +1,47 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core;
import io.r2dbc.spi.Result;
import io.r2dbc.spi.Statement;
import org.reactivestreams.Publisher;
/**
 * Collection of default {@link StatementFilterFunction}s.
 *
 * @author Mark Paluch
 * @since 5.3
 */
enum StatementFilterFunctions implements StatementFilterFunction {

	/** Singleton pass-through filter that executes the statement unchanged. */
	EMPTY_FILTER;

	@Override
	public Publisher<? extends Result> filter(Statement statement, ExecuteFunction next) {
		// No-op filter: delegate straight to the next execute function.
		return next.execute(statement);
	}

	/**
	 * Return an empty {@link StatementFilterFunction} that delegates to {@link ExecuteFunction}.
	 * @return an empty {@link StatementFilterFunction} that delegates to {@link ExecuteFunction}.
	 */
	public static StatementFilterFunction empty() {
		return EMPTY_FILTER;
	}
}

View File

@ -0,0 +1,35 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core;
import reactor.core.publisher.Mono;
/**
 * Contract for fetching the number of affected rows.
 *
 * @author Mark Paluch
 * @since 5.3
 */
public interface UpdatedRowsFetchSpec {

	/**
	 * Get the number of updated rows.
	 * @return a mono emitting the number of rows affected by the statement
	 */
	Mono<Integer> rowsUpdated();
}

View File

@ -0,0 +1,63 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core.binding;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
/**
 * Anonymous, index-based bind marker using a static placeholder.
 * Instances are bound by the ordinal position ordered by the appearance of
 * the placeholder. This implementation creates indexed bind markers using
 * an anonymous placeholder that correlates with an index.
 *
 * <p>Note: Anonymous bind markers are problematic because they have to appear
 * in generated SQL in the same order they get generated. This might cause
 * challenges in the future with complex generated statements. For example those
 * containing subselects which limit the freedom of arranging bind markers.
 *
 * @author Mark Paluch
 * @since 5.3
 */
class AnonymousBindMarkers implements BindMarkers {

	// Shared field updater avoids allocating an AtomicInteger per instance;
	// it is reflectively bound to the "counter" field name below.
	private static final AtomicIntegerFieldUpdater<AnonymousBindMarkers> COUNTER_INCREMENTER = AtomicIntegerFieldUpdater
			.newUpdater(AnonymousBindMarkers.class, "counter");

	// Static placeholder rendered into the query for every marker (e.g. "?").
	private final String placeholder;

	// access via COUNTER_INCREMENTER
	@SuppressWarnings("unused")
	private volatile int counter = 0;

	/**
	 * Create a new {@link AnonymousBindMarkers} instance given {@code placeholder}.
	 * @param placeholder parameter bind marker
	 */
	AnonymousBindMarkers(String placeholder) {
		this.placeholder = placeholder;
	}

	@Override
	public BindMarker next() {
		// Zero-based ordinal; reuses the indexed marker implementation, but with
		// the same static placeholder for every marker.
		int index = COUNTER_INCREMENTER.getAndIncrement(this);
		return new IndexedBindMarkers.IndexedBindMarker(this.placeholder, index);
	}
}

View File

@ -0,0 +1,57 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core.binding;
import io.r2dbc.spi.Statement;
/**
 * A bind marker represents a single bindable parameter within a query.
 * Bind markers are dialect-specific and provide a
 * {@link #getPlaceholder() placeholder} that is used in the actual query.
 *
 * @author Mark Paluch
 * @since 5.3
 * @see Statement#bind
 * @see BindMarkers
 * @see BindMarkersFactory
 */
public interface BindMarker {

	/**
	 * Return the database-specific placeholder for a given substitution.
	 * @return the placeholder to be included in the query string
	 */
	String getPlaceholder();

	/**
	 * Bind the given {@code value} to the {@link Statement} using the underlying binding strategy.
	 * @param bindTarget the target to bind the value to
	 * @param value the actual value; must not be {@code null}.
	 * Use {@link #bindNull(BindTarget, Class)} for {@code null} values
	 * @see Statement#bind
	 */
	void bind(BindTarget bindTarget, Object value);

	/**
	 * Bind a {@code null} value to the {@link Statement} using the underlying binding strategy.
	 * @param bindTarget the target to bind the value to
	 * @param valueType the value type; must not be {@code null}
	 * @see Statement#bindNull
	 */
	void bindNull(BindTarget bindTarget, Class<?> valueType);
}

View File

@ -0,0 +1,54 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core.binding;
/**
 * Bind markers represent placeholders in SQL queries for substitution
 * of an actual parameter. Using bind markers allows creating safe queries
 * so query strings are not required to contain escaped values but rather
 * the driver encodes the parameter in the appropriate representation.
 *
 * <p>{@link BindMarkers} is stateful and can be only used for a single binding
 * pass of one or more parameters. It maintains bind indexes/bind parameter names.
 *
 * @author Mark Paluch
 * @since 5.3
 * @see BindMarker
 * @see BindMarkersFactory
 * @see io.r2dbc.spi.Statement#bind
 */
@FunctionalInterface
public interface BindMarkers {

	/**
	 * Create a new {@link BindMarker}.
	 * @return a new {@link BindMarker}
	 */
	BindMarker next();

	/**
	 * Create a new {@link BindMarker} that accepts a {@code hint}.
	 * Implementations are allowed to consider/ignore/filter
	 * the name hint to create more expressive bind markers.
	 * <p>By default the hint is ignored and plain {@link #next()} is used.
	 * @param hint an optional name hint that can be used as part of the bind marker
	 * @return a new {@link BindMarker}
	 */
	default BindMarker next(String hint) {
		return next();
	}
}

View File

@ -0,0 +1,149 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core.binding;
import java.util.function.Function;
import org.springframework.util.Assert;
/**
 * This class creates new {@link BindMarkers} instances to bind
 * parameter to a specific {@link io.r2dbc.spi.Statement}.
 *
 * <p>Bind markers can be typically represented as placeholder and identifier.
 * Placeholders are used within the query to execute so the underlying database
 * system can substitute the placeholder with the actual value. Identifiers
 * are used in R2DBC drivers to bind a value to a bind marker. Identifiers are
 * typically a part of an entire bind marker when using indexed or named bind markers.
 *
 * @author Mark Paluch
 * @since 5.3
 * @see BindMarkers
 * @see io.r2dbc.spi.Statement
 */
@FunctionalInterface
public interface BindMarkersFactory {

	/**
	 * Create a new {@link BindMarkers} instance.
	 * <p>Each returned instance maintains its own bind index/name state
	 * and is intended for a single binding pass.
	 * @return a new {@link BindMarkers} instance
	 */
	BindMarkers create();

	/**
	 * Return whether the {@link BindMarkersFactory} uses identifiable
	 * placeholders.
	 * @return whether the {@link BindMarkersFactory} uses identifiable
	 * placeholders. {@code false} if multiple placeholders cannot be
	 * distinguished by just the {@link BindMarker#getPlaceholder() placeholder}
	 * identifier.
	 */
	default boolean identifiablePlaceholders() {
		return true;
	}


	// Static factory methods

	/**
	 * Create index-based {@link BindMarkers} using indexes to bind parameters.
	 * Allows customization of the bind marker placeholder {@code prefix} to
	 * represent the bind marker as placeholder within the query.
	 * @param prefix bind parameter prefix that is included in
	 * {@link BindMarker#getPlaceholder()} but not the actual identifier
	 * @param beginWith the first index to use
	 * @return a {@link BindMarkersFactory} using {@code prefix} and {@code beginWith}
	 * @see io.r2dbc.spi.Statement#bindNull(int, Class)
	 * @see io.r2dbc.spi.Statement#bind(int, Object)
	 */
	static BindMarkersFactory indexed(String prefix, int beginWith) {
		Assert.notNull(prefix, "Prefix must not be null");
		return () -> new IndexedBindMarkers(prefix, beginWith);
	}

	/**
	 * Create anonymous, index-based bind marker using a static placeholder.
	 * Instances are bound by the ordinal position ordered by the appearance
	 * of the placeholder. This implementation creates indexed bind markers
	 * using an anonymous placeholder that correlates with an index.
	 * @param placeholder parameter placeholder
	 * @return a {@link BindMarkersFactory} using {@code placeholder}
	 * @see io.r2dbc.spi.Statement#bindNull(int, Class)
	 * @see io.r2dbc.spi.Statement#bind(int, Object)
	 */
	static BindMarkersFactory anonymous(String placeholder) {
		Assert.hasText(placeholder, "Placeholder must not be empty!");
		// Anonymous placeholders (e.g. "?") are not identifiable: multiple markers
		// share the same placeholder text, hence the overridden method below.
		return new BindMarkersFactory() {
			@Override
			public BindMarkers create() {
				return new AnonymousBindMarkers(placeholder);
			}
			@Override
			public boolean identifiablePlaceholders() {
				return false;
			}
		};
	}

	/**
	 * Create named {@link BindMarkers} using identifiers to bind parameters.
	 * Named bind markers can support {@link BindMarkers#next(String) name hints}.
	 * If no {@link BindMarkers#next(String) hint} is given, named bind markers can
	 * use a counter or a random value source to generate unique bind markers.
	 * Allows customization of the bind marker placeholder {@code prefix} and
	 * {@code namePrefix} to represent the bind marker as placeholder within
	 * the query.
	 * @param prefix bind parameter prefix that is included in
	 * {@link BindMarker#getPlaceholder()} but not the actual identifier
	 * @param namePrefix prefix for bind marker name that is included in
	 * {@link BindMarker#getPlaceholder()} and the actual identifier
	 * @param maxLength maximal length of parameter names when using name hints
	 * @return a {@link BindMarkersFactory} using {@code prefix} and {@code namePrefix}
	 * @see io.r2dbc.spi.Statement#bindNull(String, Class)
	 * @see io.r2dbc.spi.Statement#bind(String, Object)
	 */
	static BindMarkersFactory named(String prefix, String namePrefix, int maxLength) {
		return named(prefix, namePrefix, maxLength, Function.identity());
	}

	/**
	 * Create named {@link BindMarkers} using identifiers to bind parameters.
	 * Named bind markers support {@link BindMarkers#next(String) name hints}.
	 * If no {@link BindMarkers#next(String) hint} is given, named bind markers
	 * can use a counter or a random value source to generate unique bind markers.
	 * @param prefix bind parameter prefix that is included in
	 * {@link BindMarker#getPlaceholder()} but not the actual identifier
	 * @param namePrefix prefix for bind marker name that is included in
	 * {@link BindMarker#getPlaceholder()} and the actual identifier
	 * @param maxLength maximal length of parameter names when using name hints
	 * @param hintFilterFunction filter {@link Function} to consider
	 * database-specific limitations in bind marker/variable names such as ASCII chars only
	 * @return a {@link BindMarkersFactory} using {@code prefix} and {@code namePrefix}
	 * @see io.r2dbc.spi.Statement#bindNull(String, Class)
	 * @see io.r2dbc.spi.Statement#bind(String, Object)
	 */
	static BindMarkersFactory named(String prefix, String namePrefix, int maxLength,
			Function<String, String> hintFilterFunction) {
		Assert.notNull(prefix, "Prefix must not be null");
		Assert.notNull(namePrefix, "Index prefix must not be null");
		Assert.notNull(hintFilterFunction, "Hint filter function must not be null");
		return () -> new NamedBindMarkers(prefix, namePrefix, maxLength, hintFilterFunction);
	}
}

View File

@ -0,0 +1,181 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core.binding;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.ConnectionFactoryMetadata;
import org.springframework.core.io.support.SpringFactoriesLoader;
import org.springframework.dao.NonTransientDataAccessException;
import org.springframework.lang.Nullable;
import org.springframework.util.LinkedCaseInsensitiveMap;
/**
 * Resolves a {@link BindMarkersFactory} from a {@link ConnectionFactory} using
 * {@link BindMarkerFactoryProvider}. Dialect resolution uses Spring's
 * {@link SpringFactoriesLoader spring.factories} to determine available extensions.
 *
 * @author Mark Paluch
 * @since 5.3
 * @see BindMarkersFactory
 * @see SpringFactoriesLoader
 */
public final class BindMarkersFactoryResolver {

	// All providers discovered via spring.factories; resolved once at class-load time.
	private static final List<BindMarkerFactoryProvider> DETECTORS = SpringFactoriesLoader.loadFactories(
			BindMarkerFactoryProvider.class, BindMarkersFactoryResolver.class.getClassLoader());

	/**
	 * Retrieve a {@link BindMarkersFactory} by inspecting {@link ConnectionFactory}
	 * and its metadata.
	 * <p>Providers are consulted in the order returned by {@link SpringFactoriesLoader};
	 * the first provider returning a non-null factory wins.
	 * @param connectionFactory the connection factory to inspect
	 * @return the resolved {@link BindMarkersFactory}
	 * @throws NoBindMarkersFactoryException if no {@link BindMarkersFactory} can be
	 * resolved
	 */
	public static BindMarkersFactory resolve(ConnectionFactory connectionFactory) {
		for (BindMarkerFactoryProvider detector : DETECTORS) {
			BindMarkersFactory bindMarkersFactory = detector.getBindMarkers(
					connectionFactory);
			if (bindMarkersFactory != null) {
				return bindMarkersFactory;
			}
		}
		throw new NoBindMarkersFactoryException(
				String.format("Cannot determine a BindMarkersFactory for %s using %s",
						connectionFactory.getMetadata().getName(), connectionFactory));
	}

	// Utility class: no instantiation.
	private BindMarkersFactoryResolver() {
	}

	/**
	 * SPI to extend Spring's default R2DBC BindMarkersFactory discovery mechanism.
	 * Implementations of this interface are discovered through Spring's
	 * {@link SpringFactoriesLoader} mechanism.
	 * @see SpringFactoriesLoader
	 */
	@FunctionalInterface
	public interface BindMarkerFactoryProvider {

		/**
		 * Return a {@link BindMarkersFactory} for a {@link ConnectionFactory}.
		 * @param connectionFactory the connection factory to be used with the
		 * {@link BindMarkersFactory}
		 * @return the {@link BindMarkersFactory} if the {@link BindMarkerFactoryProvider}
		 * can provide a bind marker factory object, otherwise {@code null}
		 */
		@Nullable
		BindMarkersFactory getBindMarkers(ConnectionFactory connectionFactory);
	}

	/**
	 * Exception thrown when {@link BindMarkersFactoryResolver} cannot resolve a
	 * {@link BindMarkersFactory}.
	 */
	@SuppressWarnings("serial")
	public static class NoBindMarkersFactoryException
			extends NonTransientDataAccessException {

		/**
		 * Constructor for NoBindMarkersFactoryException.
		 * @param msg the detail message
		 */
		public NoBindMarkersFactoryException(String msg) {
			super(msg);
		}
	}

	/**
	 * Built-in bind maker factories. Used typically as last {@link BindMarkerFactoryProvider}
	 * when other providers register with a higher precedence.
	 * <p>NOTE(review): the javadoc references {@link org.springframework.core.Ordered}
	 * but this class declares no explicit order; it relies on unordered elements
	 * sorting last in {@code OrderComparator} — confirm this is intended.
	 * @see org.springframework.core.Ordered
	 * @see org.springframework.core.annotation.AnnotationAwareOrderComparator
	 */
	static class BuiltInBindMarkersFactoryProvider implements BindMarkerFactoryProvider {

		// Known database product names mapped to their bind marker style;
		// case-insensitive keys to tolerate driver metadata casing differences.
		private static final Map<String, BindMarkersFactory> BUILTIN = new LinkedCaseInsensitiveMap<>(
				Locale.ENGLISH);

		static {
			BUILTIN.put("H2", BindMarkersFactory.indexed("$", 1));
			BUILTIN.put("Microsoft SQL Server", BindMarkersFactory.named("@", "P", 32,
					BuiltInBindMarkersFactoryProvider::filterBindMarker));
			BUILTIN.put("MySQL", BindMarkersFactory.anonymous("?"));
			BUILTIN.put("MariaDB", BindMarkersFactory.anonymous("?"));
			BUILTIN.put("PostgreSQL", BindMarkersFactory.indexed("$", 1));
		}

		@Override
		public BindMarkersFactory getBindMarkers(ConnectionFactory connectionFactory) {
			ConnectionFactoryMetadata metadata = connectionFactory.getMetadata();
			// First try an exact (case-insensitive) lookup by product name.
			BindMarkersFactory r2dbcDialect = BUILTIN.get(metadata.getName());
			if (r2dbcDialect != null) {
				return r2dbcDialect;
			}
			// Fall back to substring matching; note this 'contains' check is
			// case-sensitive, unlike the map lookup above.
			for (String it : BUILTIN.keySet()) {
				if (metadata.getName().contains(it)) {
					return BUILTIN.get(it);
				}
			}
			return null;
		}

		// Keeps only ASCII letters and digits from a name hint; non-empty results
		// are prefixed with "_" — presumably so generated names never start with
		// a digit; TODO confirm against SQL Server identifier rules.
		private static String filterBindMarker(CharSequence input) {
			StringBuilder builder = new StringBuilder();
			for (int i = 0; i < input.length(); i++) {
				char ch = input.charAt(i);
				// ascii letter or digit
				if (Character.isLetterOrDigit(ch) && ch < 127) {
					builder.append(ch);
				}
			}
			if (builder.length() == 0) {
				return "";
			}
			return "_" + builder.toString();
		}
	}
}

View File

@ -0,0 +1,57 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core.binding;
/**
 * Target to apply bindings to.
 *
 * @author Mark Paluch
 * @since 5.3
 * @see io.r2dbc.spi.Statement#bind
 * @see io.r2dbc.spi.Statement#bindNull
 */
public interface BindTarget {

	/**
	 * Bind a value.
	 * @param identifier the identifier to bind to
	 * @param value the value to bind
	 */
	void bind(String identifier, Object value);

	/**
	 * Bind a value to an index. Indexes are zero-based.
	 * @param index the index to bind to
	 * @param value the value to bind
	 */
	void bind(int index, Object value);

	/**
	 * Bind a {@code null} value.
	 * @param identifier the identifier to bind to
	 * @param type the type of {@code null} value
	 */
	void bindNull(String identifier, Class<?> type);

	/**
	 * Bind a {@code null} value. Indexes are zero-based.
	 * @param index the index to bind to
	 * @param type the type of {@code null} value
	 */
	void bindNull(int index, Class<?> type);
}

View File

@ -0,0 +1,262 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core.binding;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Spliterator;
import java.util.function.Consumer;
import io.r2dbc.spi.Statement;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
/**
 * Value object representing value and {@code null} bindings
 * for a {@link Statement} using {@link BindMarkers}.
 * Bindings are typically immutable.
 *
 * @author Mark Paluch
 * @since 5.3
 */
public class Bindings implements Iterable<Bindings.Binding> {

	// Shared immutable instance returned by empty().
	private static final Bindings EMPTY = new Bindings();

	// Insertion-ordered mapping from bind marker to its binding.
	private final Map<BindMarker, Binding> bindings;

	/**
	 * Create empty {@link Bindings}.
	 */
	public Bindings() {
		this.bindings = Collections.emptyMap();
	}

	/**
	 * Create {@link Bindings} from the given {@link Collection} of {@link Binding bindings}.
	 * <p>Later bindings for the same {@link BindMarker} override earlier ones.
	 * @param bindings must not be {@code null}
	 */
	public Bindings(Collection<Binding> bindings) {
		Assert.notNull(bindings, "Bindings must not be null");
		Map<BindMarker, Binding> mapping = new LinkedHashMap<>(bindings.size());
		bindings.forEach(binding -> mapping.put(binding.getBindMarker(), binding));
		this.bindings = mapping;
	}

	// Package-private constructor taking ownership of the given map (no defensive copy).
	Bindings(Map<BindMarker, Binding> bindings) {
		this.bindings = bindings;
	}

	/**
	 * Create a new, empty {@link Bindings} object.
	 * @return a new, empty {@link Bindings} object.
	 */
	public static Bindings empty() {
		return EMPTY;
	}

	/**
	 * Return the underlying marker-to-binding map.
	 * @return the underlying bindings
	 */
	protected Map<BindMarker, Binding> getBindings() {
		return this.bindings;
	}

	/**
	 * Merge two {@link Bindings} objects and create a new merged
	 * {@link Bindings} object. Bindings from {@code right} override
	 * bindings from {@code left} for the same {@link BindMarker}.
	 * @param left the left object to merge with
	 * @param right the right object to merge with
	 * @return a new, merged {@link Bindings} object
	 */
	public static Bindings merge(Bindings left, Bindings right) {
		Assert.notNull(left, "Left side Bindings must not be null");
		Assert.notNull(right, "Right side Bindings must not be null");
		List<Binding> result = new ArrayList<>(
				left.getBindings().size() + right.getBindings().size());
		result.addAll(left.getBindings().values());
		result.addAll(right.getBindings().values());
		return new Bindings(result);
	}

	/**
	 * Merge this bindings with an other {@link Bindings} object and create a new merged
	 * {@link Bindings} object.
	 * @param other the object to merge with
	 * @return a new, merged {@link Bindings} object
	 */
	public Bindings and(Bindings other) {
		return merge(this, other);
	}

	/**
	 * Apply the bindings to a {@link BindTarget}.
	 * @param bindTarget the target to apply bindings to
	 */
	public void apply(BindTarget bindTarget) {
		Assert.notNull(bindTarget, "BindTarget must not be null");
		this.bindings.forEach((marker, binding) -> binding.apply(bindTarget));
	}

	/**
	 * Perform the given action for each binding of this {@link Bindings} until all
	 * bindings have been processed or the action throws an exception. Actions are
	 * performed in the order of iteration (if an iteration order is specified).
	 * Exceptions thrown by the action are relayed to the caller.
	 * @param action the action to be performed for each {@link Binding}
	 */
	public void forEach(Consumer<? super Binding> action) {
		this.bindings.forEach((marker, binding) -> action.accept(binding));
	}

	@Override
	public Iterator<Binding> iterator() {
		return this.bindings.values().iterator();
	}

	@Override
	public Spliterator<Binding> spliterator() {
		return this.bindings.values().spliterator();
	}

	/**
	 * Base class for value objects representing a value or a {@code NULL} binding.
	 */
	public abstract static class Binding {

		private final BindMarker marker;

		protected Binding(BindMarker marker) {
			this.marker = marker;
		}

		/**
		 * Return the associated {@link BindMarker}.
		 * @return the associated {@link BindMarker}.
		 */
		public BindMarker getBindMarker() {
			return this.marker;
		}

		/**
		 * Return whether the binding has a value associated with it.
		 * @return {@code true} if there is a value present, otherwise {@code false}
		 * for a {@code NULL} binding.
		 */
		public abstract boolean hasValue();

		/**
		 * Return whether the binding is empty.
		 * @return {@code true} if this is a {@code NULL} binding
		 */
		public boolean isNull() {
			return !hasValue();
		}

		/**
		 * Return the binding value.
		 * @return value of this binding. Can be {@code null}
		 * if this is a {@code NULL} binding.
		 */
		@Nullable
		public abstract Object getValue();

		/**
		 * Apply the binding to a {@link BindTarget}.
		 * @param bindTarget the target to apply bindings to
		 */
		public abstract void apply(BindTarget bindTarget);
	}

	/**
	 * Value binding.
	 */
	static class ValueBinding extends Binding {

		private final Object value;

		ValueBinding(BindMarker marker, Object value) {
			super(marker);
			this.value = value;
		}

		@Override
		public boolean hasValue() {
			return true;
		}

		@Override
		public Object getValue() {
			return this.value;
		}

		@Override
		public void apply(BindTarget bindTarget) {
			getBindMarker().bind(bindTarget, getValue());
		}
	}

	/**
	 * {@code NULL} binding.
	 */
	static class NullBinding extends Binding {

		private final Class<?> valueType;

		NullBinding(BindMarker marker, Class<?> valueType) {
			super(marker);
			this.valueType = valueType;
		}

		@Override
		public boolean hasValue() {
			return false;
		}

		@Override
		@Nullable
		public Object getValue() {
			return null;
		}

		// Return the declared type of the NULL value (used for driver-side typing).
		public Class<?> getValueType() {
			return this.valueType;
		}

		@Override
		public void apply(BindTarget bindTarget) {
			getBindMarker().bindNull(bindTarget, getValueType());
		}
	}
}

View File

@ -0,0 +1,100 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core.binding;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
/**
 * Index-based bind marker. This implementation creates indexed bind
 * markers using a numeric index and an optional prefix for bind markers
 * to be represented within the query string.
 *
 * @author Mark Paluch
 * @author Jens Schauder
 * @since 5.3
 */
class IndexedBindMarkers implements BindMarkers {

	// Shared field updater avoids allocating an AtomicInteger per instance;
	// it is reflectively bound to the "counter" field name below.
	private static final AtomicIntegerFieldUpdater<IndexedBindMarkers> COUNTER_INCREMENTER = AtomicIntegerFieldUpdater
			.newUpdater(IndexedBindMarkers.class, "counter");

	// Added to the zero-based counter to form the visible placeholder number.
	private final int offset;

	// Placeholder prefix rendered into the query (e.g. "$" for "$1").
	private final String prefix;

	// access via COUNTER_INCREMENTER
	@SuppressWarnings("unused")
	private volatile int counter;

	/**
	 * Create a new {@link IndexedBindMarkers} instance given {@code prefix} and {@code beginWith}.
	 * @param prefix bind parameter prefix
	 * @param beginWith the first index to use
	 */
	IndexedBindMarkers(String prefix, int beginWith) {
		this.counter = 0;
		this.prefix = prefix;
		this.offset = beginWith;
	}

	@Override
	public BindMarker next() {
		int index = COUNTER_INCREMENTER.getAndIncrement(this);
		// The visible placeholder carries the offset (e.g. "$1" when beginWith is 1),
		// while the driver-facing bind index stays zero-based.
		return new IndexedBindMarker(this.prefix + (index + this.offset), index);
	}

	/**
	 * A single indexed bind marker.
	 * @author Mark Paluch
	 */
	static class IndexedBindMarker implements BindMarker {

		private final String placeholder;

		private final int index;

		IndexedBindMarker(String placeholder, int index) {
			this.placeholder = placeholder;
			this.index = index;
		}

		@Override
		public String getPlaceholder() {
			return this.placeholder;
		}

		@Override
		public void bind(BindTarget target, Object value) {
			target.bind(this.index, value);
		}

		@Override
		public void bindNull(BindTarget target, Class<?> valueType) {
			target.bindNull(this.index, valueType);
		}

		/**
		 * Return the zero-based index used for driver binding.
		 * @return the zero-based bind index
		 */
		public int getIndex() {
			return this.index;
		}
	}
}

View File

@ -0,0 +1,115 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core.binding;
import java.util.LinkedHashMap;
import io.r2dbc.spi.Statement;
import org.springframework.util.Assert;
/**
 * Mutable extension to {@link Bindings} that registers value and {@code NULL}
 * bindings for a {@link Statement} using {@link BindMarkers}.
 *
 * @author Mark Paluch
 * @since 5.3
 */
public class MutableBindings extends Bindings {

	private final BindMarkers markers;


	/**
	 * Create new {@link MutableBindings} backed by the given {@link BindMarkers}.
	 * @param markers must not be {@code null}
	 */
	public MutableBindings(BindMarkers markers) {
		super(new LinkedHashMap<>());
		Assert.notNull(markers, "BindMarkers must not be null");
		this.markers = markers;
	}


	/**
	 * Obtain the next {@link BindMarker}, advancing the {@link BindMarkers} state.
	 * @return the next {@link BindMarker}
	 */
	public BindMarker nextMarker() {
		return this.markers.next();
	}

	/**
	 * Obtain the next {@link BindMarker} using the given name {@code hint},
	 * advancing the {@link BindMarkers} state.
	 * @param hint name hint
	 * @return the next {@link BindMarker}
	 */
	public BindMarker nextMarker(String hint) {
		return this.markers.next(hint);
	}

	/**
	 * Register a value binding for the given {@link BindMarker}.
	 * @param marker must not be {@code null}
	 * @param value must not be {@code null}
	 * @return {@code this} {@link MutableBindings}
	 */
	public MutableBindings bind(BindMarker marker, Object value) {
		Assert.notNull(marker, "BindMarker must not be null");
		Assert.notNull(value, "Value must not be null");
		getBindings().put(marker, new ValueBinding(marker, value));
		return this;
	}

	/**
	 * Bind a value to the next {@link BindMarker} and return that marker,
	 * advancing the {@link BindMarkers} state.
	 * @param value must not be {@code null}
	 * @return the {@link BindMarker} associated with the bound value
	 */
	public BindMarker bind(Object value) {
		Assert.notNull(value, "Value must not be null");
		BindMarker marker = nextMarker();
		bind(marker, value);
		return marker;
	}

	/**
	 * Register a {@code NULL} binding for the given {@link BindMarker}.
	 * @param marker must not be {@code null}
	 * @param valueType must not be {@code null}
	 * @return {@code this} {@link MutableBindings}
	 */
	public MutableBindings bindNull(BindMarker marker, Class<?> valueType) {
		Assert.notNull(marker, "BindMarker must not be null");
		Assert.notNull(valueType, "Value type must not be null");
		getBindings().put(marker, new NullBinding(marker, valueType));
		return this;
	}

	/**
	 * Bind a {@code NULL} value of the given type to the next {@link BindMarker}
	 * and return that marker, advancing the {@link BindMarkers} state.
	 * @param valueType must not be {@code null}
	 * @return the {@link BindMarker} associated with the {@code NULL} binding
	 */
	public BindMarker bindNull(Class<?> valueType) {
		Assert.notNull(valueType, "Value type must not be null");
		BindMarker marker = nextMarker();
		bindNull(marker, valueType);
		return marker;
	}
}

View File

@ -0,0 +1,113 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core.binding;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import java.util.function.Function;
import org.springframework.util.Assert;
/**
 * Name-based bind markers.
 *
 * @author Mark Paluch
 * @since 5.3
 */
class NamedBindMarkers implements BindMarkers {

	private static final AtomicIntegerFieldUpdater<NamedBindMarkers> COUNTER_INCREMENTER =
			AtomicIntegerFieldUpdater.newUpdater(NamedBindMarkers.class, "counter");


	private final String prefix;

	private final String namePrefix;

	private final int nameLimit;

	private final Function<String, String> hintFilterFunction;

	// access via COUNTER_INCREMENTER
	@SuppressWarnings("unused")
	private volatile int counter;


	NamedBindMarkers(String prefix, String namePrefix, int nameLimit, Function<String, String> hintFilterFunction) {
		this.prefix = prefix;
		this.namePrefix = namePrefix;
		this.nameLimit = nameLimit;
		this.hintFilterFunction = hintFilterFunction;
	}


	@Override
	public BindMarker next() {
		return toMarker(nextName());
	}

	@Override
	public BindMarker next(String hint) {
		Assert.notNull(hint, "Name hint must not be null");
		// Append the filtered hint to the generated name, respecting the length limit.
		return toMarker(truncate(nextName() + this.hintFilterFunction.apply(hint)));
	}

	// Atomically advance the counter and derive the next unique parameter name.
	private String nextName() {
		return this.namePrefix + COUNTER_INCREMENTER.getAndIncrement(this);
	}

	// Enforce the configured name length limit.
	private String truncate(String name) {
		return (name.length() > this.nameLimit ? name.substring(0, this.nameLimit) : name);
	}

	// Render the marker from a generated name by prepending the marker prefix.
	private BindMarker toMarker(String name) {
		return new NamedBindMarker(this.prefix + name, name);
	}


	/**
	 * A single named bind marker.
	 */
	static class NamedBindMarker implements BindMarker {

		private final String placeholder;

		private final String identifier;

		NamedBindMarker(String placeholder, String identifier) {
			this.placeholder = placeholder;
			this.identifier = identifier;
		}

		@Override
		public String getPlaceholder() {
			return this.placeholder;
		}

		@Override
		public void bind(BindTarget target, Object value) {
			target.bind(this.identifier, value);
		}

		@Override
		public void bindNull(BindTarget target, Class<?> valueType) {
			target.bindNull(this.identifier, valueType);
		}
	}
}

View File

@ -0,0 +1,9 @@
/**
* Classes providing an abstraction over SQL bind markers.
*/
@NonNullApi
@NonNullFields
package org.springframework.r2dbc.core.binding;
import org.springframework.lang.NonNullApi;
import org.springframework.lang.NonNullFields;

View File

@ -0,0 +1,6 @@
/**
* Core domain types around DatabaseClient.
*/
@org.springframework.lang.NonNullApi
@org.springframework.lang.NonNullFields
package org.springframework.r2dbc.core;

View File

@ -0,0 +1,21 @@
/**
* The classes in this package make R2DBC easier to use and
* reduce the likelihood of common errors. In particular, they:
* <ul>
* <li>Simplify error handling, avoiding the need for resource management
* blocks in application code.
* <li>Present exceptions to application code in a generic hierarchy of
* unchecked exceptions, enabling applications to catch data access
* exceptions without being dependent on R2DBC, and to ignore fatal
 * exceptions that there is no value in catching.
* <li>Allow the implementation of error handling to be modified
* to target different RDBMSes without introducing proprietary
* dependencies into application code.
* </ul>
*/
@NonNullApi
@NonNullFields
package org.springframework.r2dbc;
import org.springframework.lang.NonNullApi;
import org.springframework.lang.NonNullFields;

View File

@ -0,0 +1,46 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core
import kotlinx.coroutines.reactive.awaitFirstOrNull
/**
 * Coroutines variant of [DatabaseClient.GenericExecuteSpec.then].
 *
 * @author Sebastien Deleuze
 * @since 5.3
 */
suspend fun DatabaseClient.GenericExecuteSpec.await() {
then().awaitFirstOrNull()
}
/**
 * Extension for [DatabaseClient.BindSpec.bind] providing a variant leveraging reified type parameters.
 *
 * Delegates to [Parameter.fromOrEmpty] so that a `null` [value] can be bound with its type [T].
 *
 * @author Mark Paluch
 * @author Ibanga Enoobong Ime
 * @since 5.3
 */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER")
inline fun <reified T : Any> DatabaseClient.GenericExecuteSpec.bind(index: Int, value: T?) = bind(index, Parameter.fromOrEmpty(value, T::class.java))
/**
 * Extension for [DatabaseClient.BindSpec.bind] providing a variant leveraging reified type parameters.
 *
 * Delegates to [Parameter.fromOrEmpty] so that a `null` [value] can be bound with its type [T].
 *
 * @author Mark Paluch
 * @author Ibanga Enoobong Ime
 * @since 5.3
 */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER")
inline fun <reified T : Any> DatabaseClient.GenericExecuteSpec.bind(name: String, value: T?) = bind(name, Parameter.fromOrEmpty(value, T::class.java))

View File

@ -0,0 +1,62 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.reactive.asFlow
import kotlinx.coroutines.reactive.awaitFirstOrNull
import org.springframework.dao.EmptyResultDataAccessException
/**
 * Non-nullable Coroutines variant of [RowsFetchSpec.one].
 *
 * @throws EmptyResultDataAccessException if [RowsFetchSpec.one] completes without a value
 * @author Sebastien Deleuze
 * @since 5.3
 */
suspend fun <T> RowsFetchSpec<T>.awaitOne(): T {
return one().awaitFirstOrNull() ?: throw EmptyResultDataAccessException(1)
}
/**
 * Nullable Coroutines variant of [RowsFetchSpec.one].
 *
 * @author Sebastien Deleuze
 * @since 5.3
 */
suspend fun <T> RowsFetchSpec<T>.awaitOneOrNull(): T? =
one().awaitFirstOrNull()
/**
 * Non-nullable Coroutines variant of [RowsFetchSpec.first].
 *
 * @throws EmptyResultDataAccessException if [RowsFetchSpec.first] completes without a value
 * @author Sebastien Deleuze
 * @since 5.3
 */
suspend fun <T> RowsFetchSpec<T>.awaitFirst(): T {
return first().awaitFirstOrNull() ?: throw EmptyResultDataAccessException(1)
}
/**
 * Nullable Coroutines variant of [RowsFetchSpec.first].
 *
 * @author Sebastien Deleuze
 * @since 5.3
 */
suspend fun <T> RowsFetchSpec<T>.awaitFirstOrNull(): T? =
first().awaitFirstOrNull()
/**
 * Coroutines [Flow] variant of [RowsFetchSpec.all].
 *
 * @author Sebastien Deleuze
 * @since 5.3
 */
fun <T : Any> RowsFetchSpec<T>.flow(): Flow<T> = all().asFlow()

View File

@ -0,0 +1,27 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core
import kotlinx.coroutines.reactive.awaitSingle
/**
 * Coroutines variant of [UpdatedRowsFetchSpec.rowsUpdated].
 *
 * @author Fred Montariol
 * @since 5.3
 */
suspend fun UpdatedRowsFetchSpec.awaitRowsUpdated(): Int =
rowsUpdated().awaitSingle()

View File

@ -0,0 +1 @@
org.springframework.r2dbc.core.binding.BindMarkersFactoryResolver$BindMarkerFactoryProvider=org.springframework.r2dbc.core.binding.BindMarkersFactoryResolver.BuiltInBindMarkersFactoryProvider

View File

@ -0,0 +1,134 @@
/*
* Copyright 2019-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection;
import io.r2dbc.spi.R2dbcBadGrammarException;
import io.r2dbc.spi.R2dbcDataIntegrityViolationException;
import io.r2dbc.spi.R2dbcException;
import io.r2dbc.spi.R2dbcNonTransientResourceException;
import io.r2dbc.spi.R2dbcPermissionDeniedException;
import io.r2dbc.spi.R2dbcRollbackException;
import io.r2dbc.spi.R2dbcTimeoutException;
import io.r2dbc.spi.R2dbcTransientResourceException;
import org.junit.jupiter.api.Test;
import org.springframework.dao.ConcurrencyFailureException;
import org.springframework.dao.DataAccessResourceFailureException;
import org.springframework.dao.DataIntegrityViolationException;
import org.springframework.dao.PermissionDeniedDataAccessException;
import org.springframework.dao.QueryTimeoutException;
import org.springframework.dao.TransientDataAccessResourceException;
import org.springframework.r2dbc.BadSqlGrammarException;
import org.springframework.r2dbc.UncategorizedR2dbcException;
import static org.assertj.core.api.Assertions.assertThat;
/**
 * Unit tests for {@link ConnectionFactoryUtils}.
 *
 * @author Mark Paluch
 */
public class ConnectionFactoryUtilsUnitTests {

	@Test
	public void shouldTranslateTransientResourceException() {
		assertThat(translate(new R2dbcTransientResourceException("")))
				.isInstanceOf(TransientDataAccessResourceException.class);
	}

	@Test
	public void shouldTranslateRollbackException() {
		assertThat(translate(new R2dbcRollbackException()))
				.isInstanceOf(ConcurrencyFailureException.class);
	}

	@Test
	public void shouldTranslateTimeoutException() {
		assertThat(translate(new R2dbcTimeoutException()))
				.isInstanceOf(QueryTimeoutException.class);
	}

	@Test
	public void shouldNotTranslateUnknownExceptions() {
		assertThat(translate(new MyTransientExceptions()))
				.isInstanceOf(UncategorizedR2dbcException.class);
	}

	@Test
	public void shouldTranslateNonTransientResourceException() {
		assertThat(translate(new R2dbcNonTransientResourceException()))
				.isInstanceOf(DataAccessResourceFailureException.class);
	}

	@Test
	public void shouldTranslateIntegrityViolationException() {
		assertThat(translate(new R2dbcDataIntegrityViolationException()))
				.isInstanceOf(DataIntegrityViolationException.class);
	}

	@Test
	public void shouldTranslatePermissionDeniedException() {
		assertThat(translate(new R2dbcPermissionDeniedException()))
				.isInstanceOf(PermissionDeniedDataAccessException.class);
	}

	@Test
	public void shouldTranslateBadSqlGrammarException() {
		assertThat(translate(new R2dbcBadGrammarException()))
				.isInstanceOf(BadSqlGrammarException.class);
	}

	@Test
	public void messageGeneration() {
		Exception exception = ConnectionFactoryUtils.convertR2dbcException("TASK",
				"SOME-SQL", new R2dbcTransientResourceException("MESSAGE"));
		assertThat(exception).isInstanceOf(
				TransientDataAccessResourceException.class).hasMessage(
						"TASK; SQL [SOME-SQL]; MESSAGE; nested exception is io.r2dbc.spi.R2dbcTransientResourceException: MESSAGE");
	}

	@Test
	public void messageGenerationNullSQL() {
		Exception exception = ConnectionFactoryUtils.convertR2dbcException("TASK", null,
				new R2dbcTransientResourceException("MESSAGE"));
		assertThat(exception).isInstanceOf(
				TransientDataAccessResourceException.class).hasMessage(
						"TASK; MESSAGE; nested exception is io.r2dbc.spi.R2dbcTransientResourceException: MESSAGE");
	}

	@Test
	public void messageGenerationNullMessage() {
		Exception exception = ConnectionFactoryUtils.convertR2dbcException("TASK",
				"SOME-SQL", new R2dbcTransientResourceException());
		assertThat(exception).isInstanceOf(
				TransientDataAccessResourceException.class).hasMessage(
						"TASK; SQL [SOME-SQL]; null; nested exception is io.r2dbc.spi.R2dbcTransientResourceException");
	}

	// Translate an R2dbcException using empty task and SQL descriptions.
	private static Exception translate(R2dbcException ex) {
		return ConnectionFactoryUtils.convertR2dbcException("", "", ex);
	}


	@SuppressWarnings("serial")
	private static class MyTransientExceptions extends R2dbcException {
	}

}

View File

@ -0,0 +1,63 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection;
import io.r2dbc.spi.Connection;
import io.r2dbc.spi.ConnectionFactory;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.mock;
import static org.mockito.BDDMockito.when;
/**
 * Unit tests for {@link DelegatingConnectionFactory}.
 *
 * @author Mark Paluch
 */
public class DelegatingConnectionFactoryUnitTests {

	ConnectionFactory delegate = mock(ConnectionFactory.class);

	Connection connectionMock = mock(Connection.class);

	DelegatingConnectionFactory connectionFactory = new ExampleConnectionFactory(
			delegate);


	@Test
	public void shouldDelegateGetConnection() {
		Mono<Connection> connectionMono = Mono.just(connectionMock);
		// Mono<Connection> already satisfies Publisher<? extends Connection>, so the
		// previous raw (Mono) cast (and its unchecked warning) is unnecessary.
		when(delegate.create()).thenReturn(connectionMono);

		assertThat(connectionFactory.create()).isSameAs(connectionMono);
	}

	@Test
	public void shouldDelegateUnwrapWithoutImplementing() {
		// Without overriding unwrap(), the delegate itself must be exposed.
		assertThat(connectionFactory.unwrap()).isSameAs(delegate);
	}


	static class ExampleConnectionFactory extends DelegatingConnectionFactory {

		ExampleConnectionFactory(ConnectionFactory targetConnectionFactory) {
			super(targetConnectionFactory);
		}
	}

}

View File

@ -0,0 +1,488 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection;
import java.util.concurrent.atomic.AtomicInteger;
import io.r2dbc.spi.Connection;
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.IsolationLevel;
import io.r2dbc.spi.R2dbcBadGrammarException;
import io.r2dbc.spi.Statement;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import org.springframework.transaction.CannotCreateTransactionException;
import org.springframework.transaction.IllegalTransactionStateException;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.reactive.TransactionSynchronization;
import org.springframework.transaction.reactive.TransactionSynchronizationManager;
import org.springframework.transaction.reactive.TransactionalOperator;
import org.springframework.transaction.support.DefaultTransactionDefinition;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.BDDMockito.mock;
import static org.mockito.BDDMockito.never;
import static org.mockito.BDDMockito.reset;
import static org.mockito.BDDMockito.verify;
import static org.mockito.BDDMockito.verifyNoMoreInteractions;
import static org.mockito.BDDMockito.when;
/**
* Unit tests for {@link R2dbcTransactionManager}.
*
* @author Mark Paluch
*/
public class R2dbcTransactionManagerUnitTests {
ConnectionFactory connectionFactoryMock = mock(ConnectionFactory.class);
Connection connectionMock = mock(Connection.class);
private R2dbcTransactionManager tm;
// Common happy-path stubbing: connection creation, begin and close all succeed.
@BeforeEach
public void before() {
when(connectionFactoryMock.create()).thenReturn((Mono) Mono.just(connectionMock));
when(connectionMock.beginTransaction()).thenReturn(Mono.empty());
when(connectionMock.close()).thenReturn(Mono.empty());
tm = new R2dbcTransactionManager(connectionFactoryMock);
}
// Happy path: the transaction commits exactly once, registered synchronizations
// receive all callbacks, and the connection is begun/committed/closed.
@Test
public void testSimpleTransaction() {
TestTransactionSynchronization sync = new TestTransactionSynchronization(
TransactionSynchronization.STATUS_COMMITTED);
AtomicInteger commits = new AtomicInteger();
when(connectionMock.commitTransaction()).thenReturn(
Mono.fromRunnable(commits::incrementAndGet));
TransactionalOperator operator = TransactionalOperator.create(tm);
ConnectionFactoryUtils.getConnection(connectionFactoryMock)
.flatMap(connection -> TransactionSynchronizationManager.forCurrentTransaction()
.doOnNext(synchronizationManager -> synchronizationManager.registerSynchronization(
sync)))
.as(operator::transactional)
.as(StepVerifier::create)
.expectNextCount(1)
.verifyComplete();
assertThat(commits).hasValue(1);
verify(connectionMock).isAutoCommit();
verify(connectionMock).beginTransaction();
verify(connectionMock).commitTransaction();
verify(connectionMock).close();
verifyNoMoreInteractions(connectionMock);
assertThat(sync.beforeCommitCalled).isTrue();
assertThat(sync.afterCommitCalled).isTrue();
assertThat(sync.beforeCompletionCalled).isTrue();
assertThat(sync.afterCompletionCalled).isTrue();
}
// A failure while obtaining the connection surfaces as
// CannotCreateTransactionException with the driver exception as its cause.
@Test
public void testBeginFails() {
reset(connectionFactoryMock);
when(connectionFactoryMock.create()).thenReturn(
Mono.error(new R2dbcBadGrammarException("fail")));
when(connectionMock.rollbackTransaction()).thenReturn(Mono.empty());
DefaultTransactionDefinition definition = new DefaultTransactionDefinition();
definition.setIsolationLevel(TransactionDefinition.ISOLATION_SERIALIZABLE);
TransactionalOperator operator = TransactionalOperator.create(tm, definition);
ConnectionFactoryUtils.getConnection(connectionFactoryMock).as(
operator::transactional)
.as(StepVerifier::create)
.expectErrorSatisfies(actual -> assertThat(actual).isInstanceOf(
CannotCreateTransactionException.class).hasCauseInstanceOf(
R2dbcBadGrammarException.class))
.verify();
}
// A non-default isolation level is applied for the transaction and the
// connection's previous level (READ_COMMITTED) is restored afterwards.
@Test
public void appliesIsolationLevel() {
when(connectionMock.commitTransaction()).thenReturn(Mono.empty());
when(connectionMock.getTransactionIsolationLevel()).thenReturn(
IsolationLevel.READ_COMMITTED);
when(connectionMock.setTransactionIsolationLevel(any())).thenReturn(Mono.empty());
DefaultTransactionDefinition definition = new DefaultTransactionDefinition();
definition.setIsolationLevel(TransactionDefinition.ISOLATION_SERIALIZABLE);
TransactionalOperator operator = TransactionalOperator.create(tm, definition);
ConnectionFactoryUtils.getConnection(connectionFactoryMock).as(
operator::transactional)
.as(StepVerifier::create)
.expectNextCount(1)
.verifyComplete();
verify(connectionMock).beginTransaction();
verify(connectionMock).setTransactionIsolationLevel(
IsolationLevel.READ_COMMITTED);
verify(connectionMock).setTransactionIsolationLevel(IsolationLevel.SERIALIZABLE);
verify(connectionMock).commitTransaction();
verify(connectionMock).close();
}
// No isolation-level switch is issued when the requested level already matches
// the connection's current level.
@Test
public void doesNotSetIsolationLevelIfMatch() {
when(connectionMock.getTransactionIsolationLevel()).thenReturn(
IsolationLevel.READ_COMMITTED);
when(connectionMock.commitTransaction()).thenReturn(Mono.empty());
DefaultTransactionDefinition definition = new DefaultTransactionDefinition();
definition.setIsolationLevel(TransactionDefinition.ISOLATION_READ_COMMITTED);
TransactionalOperator operator = TransactionalOperator.create(tm, definition);
ConnectionFactoryUtils.getConnection(connectionFactoryMock).as(
operator::transactional)
.as(StepVerifier::create)
.expectNextCount(1)
.verifyComplete();
verify(connectionMock).beginTransaction();
verify(connectionMock, never()).setTransactionIsolationLevel(any());
verify(connectionMock).commitTransaction();
}
// Auto-commit is left untouched when the connection reports it as already disabled.
@Test
public void doesNotSetAutoCommitDisabled() {
when(connectionMock.isAutoCommit()).thenReturn(false);
when(connectionMock.commitTransaction()).thenReturn(Mono.empty());
DefaultTransactionDefinition definition = new DefaultTransactionDefinition();
TransactionalOperator operator = TransactionalOperator.create(tm, definition);
ConnectionFactoryUtils.getConnection(connectionFactoryMock).as(
operator::transactional)
.as(StepVerifier::create)
.expectNextCount(1)
.verifyComplete();
verify(connectionMock).beginTransaction();
verify(connectionMock, never()).setAutoCommit(anyBoolean());
verify(connectionMock).commitTransaction();
}
// Auto-commit is switched off for the transaction and switched back on afterwards
// when the connection initially reports auto-commit enabled.
@Test
public void restoresAutoCommit() {
when(connectionMock.isAutoCommit()).thenReturn(true);
when(connectionMock.setAutoCommit(anyBoolean())).thenReturn(Mono.empty());
when(connectionMock.commitTransaction()).thenReturn(Mono.empty());
DefaultTransactionDefinition definition = new DefaultTransactionDefinition();
TransactionalOperator operator = TransactionalOperator.create(tm, definition);
ConnectionFactoryUtils.getConnection(connectionFactoryMock).as(
operator::transactional)
.as(StepVerifier::create)
.expectNextCount(1)
.verifyComplete();
verify(connectionMock).beginTransaction();
verify(connectionMock).setAutoCommit(false);
verify(connectionMock).setAutoCommit(true);
verify(connectionMock).commitTransaction();
verify(connectionMock).close();
}
// With enforceReadOnly enabled and a read-only definition, the manager issues
// "SET TRANSACTION READ ONLY" on the connection before committing.
@Test
public void appliesReadOnly() {
when(connectionMock.commitTransaction()).thenReturn(Mono.empty());
when(connectionMock.setTransactionIsolationLevel(any())).thenReturn(Mono.empty());
Statement statement = mock(Statement.class);
when(connectionMock.createStatement(anyString())).thenReturn(statement);
when(statement.execute()).thenReturn(Mono.empty());
tm.setEnforceReadOnly(true);
DefaultTransactionDefinition definition = new DefaultTransactionDefinition();
definition.setReadOnly(true);
TransactionalOperator operator = TransactionalOperator.create(tm, definition);
ConnectionFactoryUtils.getConnection(connectionFactoryMock).as(
operator::transactional)
.as(StepVerifier::create)
.expectNextCount(1)
.verifyComplete();
verify(connectionMock).isAutoCommit();
verify(connectionMock).beginTransaction();
verify(connectionMock).createStatement("SET TRANSACTION READ ONLY");
verify(connectionMock).commitTransaction();
verify(connectionMock).close();
verifyNoMoreInteractions(connectionMock);
}
// A failing commit surfaces as IllegalTransactionStateException; the connection is
// still closed and no further interactions occur.
@Test
public void testCommitFails() {
when(connectionMock.commitTransaction()).thenReturn(Mono.defer(() -> Mono.error(new R2dbcBadGrammarException("Commit should fail"))));
when(connectionMock.rollbackTransaction()).thenReturn(Mono.empty());
TransactionalOperator operator = TransactionalOperator.create(tm);
ConnectionFactoryUtils.getConnection(connectionFactoryMock)
.doOnNext(connection -> connection.createStatement("foo")).then()
.as(operator::transactional)
.as(StepVerifier::create)
.verifyError(IllegalTransactionStateException.class);
verify(connectionMock).isAutoCommit();
verify(connectionMock).beginTransaction();
verify(connectionMock).createStatement("foo");
verify(connectionMock).commitTransaction();
verify(connectionMock).close();
verifyNoMoreInteractions(connectionMock);
}
// An exception inside the transactional pipeline triggers exactly one rollback
// (and no commit); the original exception propagates to the subscriber.
@Test
public void testRollback() {
AtomicInteger commits = new AtomicInteger();
when(connectionMock.commitTransaction()).thenReturn(
Mono.fromRunnable(commits::incrementAndGet));
AtomicInteger rollbacks = new AtomicInteger();
when(connectionMock.rollbackTransaction()).thenReturn(
Mono.fromRunnable(rollbacks::incrementAndGet));
TransactionalOperator operator = TransactionalOperator.create(tm);
ConnectionFactoryUtils.getConnection(connectionFactoryMock)
.doOnNext(connection -> {
throw new IllegalStateException();
}).as(operator::transactional)
.as(StepVerifier::create)
.verifyError(IllegalStateException.class);
assertThat(commits).hasValue(0);
assertThat(rollbacks).hasValue(1);
verify(connectionMock).isAutoCommit();
verify(connectionMock).beginTransaction();
verify(connectionMock).rollbackTransaction();
verify(connectionMock).close();
verifyNoMoreInteractions(connectionMock);
}
// A failing rollback surfaces as IllegalTransactionStateException; the first
// rollbackTransaction() call errors, commit is never attempted.
@Test
public void testRollbackFails() {
when(connectionMock.rollbackTransaction()).thenReturn(Mono.defer(() -> Mono.error(new R2dbcBadGrammarException("Commit should fail"))), Mono.empty());
TransactionalOperator operator = TransactionalOperator.create(tm);
operator.execute(reactiveTransaction -> {
reactiveTransaction.setRollbackOnly();
return ConnectionFactoryUtils.getConnection(connectionFactoryMock)
.doOnNext(connection -> connection.createStatement("foo")).then();
}).as(StepVerifier::create)
.verifyError(IllegalTransactionStateException.class);
verify(connectionMock).isAutoCommit();
verify(connectionMock).beginTransaction();
verify(connectionMock).createStatement("foo");
verify(connectionMock, never()).commitTransaction();
verify(connectionMock).rollbackTransaction();
verify(connectionMock).close();
verifyNoMoreInteractions(connectionMock);
}
@Test
public void testTransactionSetRollbackOnly() {
when(connectionMock.rollbackTransaction()).thenReturn(Mono.empty());
TestTransactionSynchronization sync = new TestTransactionSynchronization(
TransactionSynchronization.STATUS_ROLLED_BACK);
TransactionalOperator operator = TransactionalOperator.create(tm);
operator.execute(tx -> {
tx.setRollbackOnly();
assertThat(tx.isNewTransaction()).isTrue();
return TransactionSynchronizationManager.forCurrentTransaction().doOnNext(
synchronizationManager -> {
assertThat(synchronizationManager.hasResource(connectionFactoryMock)).isTrue();
synchronizationManager.registerSynchronization(sync);
}).then();
}).as(StepVerifier::create)
.verifyComplete();
verify(connectionMock).isAutoCommit();
verify(connectionMock).beginTransaction();
verify(connectionMock).rollbackTransaction();
verify(connectionMock).close();
verifyNoMoreInteractions(connectionMock);
assertThat(sync.beforeCommitCalled).isFalse();
assertThat(sync.afterCommitCalled).isFalse();
assertThat(sync.beforeCompletionCalled).isTrue();
assertThat(sync.afterCompletionCalled).isTrue();
}
@Test
public void testPropagationNeverWithExistingTransaction() {
	// An outer transaction is opened with PROPAGATION_REQUIRES_NEW; the shared
	// DefaultTransactionDefinition is then mutated to PROPAGATION_NEVER, so the
	// inner execute(..) must reject the existing transaction.
	when(connectionMock.rollbackTransaction()).thenReturn(Mono.empty());
	DefaultTransactionDefinition definition = new DefaultTransactionDefinition();
	definition.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRES_NEW);
	TransactionalOperator operator = TransactionalOperator.create(tm);
	operator.execute(tx1 -> {
		assertThat(tx1.isNewTransaction()).isTrue();
		// NOTE(review): mutating the shared definition mid-flight relies on the
		// inner operator reading propagation only at subscription time — confirm.
		definition.setPropagationBehavior(TransactionDefinition.PROPAGATION_NEVER);
		return operator.execute(tx2 -> {
			// Must never run: PROPAGATION_NEVER with an active transaction.
			fail("Should have thrown IllegalTransactionStateException");
			return Mono.empty();
		});
	}).as(StepVerifier::create)
			.verifyError(IllegalTransactionStateException.class);
	// The outer transaction rolls back because the inner step errored.
	verify(connectionMock).rollbackTransaction();
	verify(connectionMock).close();
}
@Test
public void testPropagationSupportsAndRequiresNew() {
	// PROPAGATION_SUPPORTS without an existing transaction runs non-transactionally;
	// the nested PROPAGATION_REQUIRES_NEW then starts (and commits) a real one.
	when(connectionMock.commitTransaction()).thenReturn(Mono.empty());
	DefaultTransactionDefinition outerDefinition = new DefaultTransactionDefinition();
	outerDefinition.setPropagationBehavior(TransactionDefinition.PROPAGATION_SUPPORTS);
	TransactionalOperator outer = TransactionalOperator.create(tm, outerDefinition);
	outer.execute(outerStatus -> {
		assertThat(outerStatus.isNewTransaction()).isFalse();
		DefaultTransactionDefinition innerDefinition = new DefaultTransactionDefinition();
		innerDefinition.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRES_NEW);
		TransactionalOperator inner = TransactionalOperator.create(tm, innerDefinition);
		return inner.execute(innerStatus -> {
			assertThat(innerStatus.isNewTransaction()).isTrue();
			return Mono.empty();
		});
	}).as(StepVerifier::create).verifyComplete();
	verify(connectionMock).commitTransaction();
	verify(connectionMock).close();
}
/**
 * Test double that records which {@code TransactionSynchronization} callbacks
 * were invoked and fails fast when a commit callback fires for an unexpected
 * completion status.
 */
private static class TestTransactionSynchronization
		implements TransactionSynchronization {

	// Expected completion status (a TransactionSynchronization.STATUS_* constant).
	private int status;

	public boolean beforeCommitCalled;

	public boolean beforeCompletionCalled;

	public boolean afterCommitCalled;

	public boolean afterCompletionCalled;

	// NOTE(review): see afterCompletion(..) — this field is likely never written.
	public Throwable afterCompletionException;

	public TestTransactionSynchronization(int status) {
		this.status = status;
	}

	@Override
	public Mono<Void> suspend() {
		return Mono.empty();
	}

	@Override
	public Mono<Void> resume() {
		return Mono.empty();
	}

	@Override
	public Mono<Void> beforeCommit(boolean readOnly) {
		// Commit callbacks are only legal when a commit outcome is expected.
		if (this.status != TransactionSynchronization.STATUS_COMMITTED) {
			fail("Should never be called");
		}
		return Mono.fromRunnable(() -> {
			assertThat(this.beforeCommitCalled).isFalse();
			this.beforeCommitCalled = true;
		});
	}

	@Override
	public Mono<Void> beforeCompletion() {
		return Mono.fromRunnable(() -> {
			assertThat(this.beforeCompletionCalled).isFalse();
			this.beforeCompletionCalled = true;
		});
	}

	@Override
	public Mono<Void> afterCommit() {
		if (this.status != TransactionSynchronization.STATUS_COMMITTED) {
			fail("Should never be called");
		}
		return Mono.fromRunnable(() -> {
			assertThat(this.afterCommitCalled).isFalse();
			this.afterCommitCalled = true;
		});
	}

	@Override
	public Mono<Void> afterCompletion(int status) {
		// NOTE(review): Mono.fromRunnable(..) defers execution until subscription,
		// so this try/catch can only trap assembly-time failures; assertion errors
		// raised inside doAfterCompletion(..) propagate to the subscriber instead.
		// Confirm whether afterCompletionException is ever expected to be set.
		try {
			return Mono.fromRunnable(() -> doAfterCompletion(status));
		}
		catch (Throwable ex) {
			this.afterCompletionException = ex;
		}
		return Mono.empty();
	}

	// Records invocation and checks the status matches the expected outcome.
	protected void doAfterCompletion(int status) {
		assertThat(this.afterCompletionCalled).isFalse();
		this.afterCompletionCalled = true;
		assertThat(status).isEqualTo(this.status);
	}
}
}

View File

@ -0,0 +1,144 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection;
import io.r2dbc.h2.H2Connection;
import io.r2dbc.spi.Connection;
import io.r2dbc.spi.ConnectionFactoryMetadata;
import io.r2dbc.spi.IsolationLevel;
import io.r2dbc.spi.R2dbcNonTransientResourceException;
import io.r2dbc.spi.Wrapped;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.mock;
import static org.mockito.BDDMockito.never;
import static org.mockito.BDDMockito.verify;
import static org.mockito.BDDMockito.when;
/**
* Unit tests for {@link SingleConnectionFactory}.
*
* @author Mark Paluch
*/
public class SingleConnectionFactoryUnitTests {

	@Test
	public void shouldAllocateSameConnection() {
		// Every create() subscription must hand out the one shared connection.
		SingleConnectionFactory connectionFactory = new SingleConnectionFactory(
				"r2dbc:h2:mem:///foo", false);
		Mono<? extends Connection> firstMono = connectionFactory.create();
		Mono<? extends Connection> secondMono = connectionFactory.create();
		Connection first = firstMono.block();
		Connection second = secondMono.block();
		assertThat(first).isSameAs(second);
		connectionFactory.destroy();
	}

	@Test
	public void shouldApplyAutoCommit() {
		// Changing auto-commit on the factory is reflected on the emitted connection.
		SingleConnectionFactory connectionFactory = new SingleConnectionFactory(
				"r2dbc:h2:mem:///foo", false);
		connectionFactory.setAutoCommit(false);
		connectionFactory.create().as(StepVerifier::create)
				.consumeNextWith(emitted -> assertThat(emitted.isAutoCommit()).isFalse())
				.verifyComplete();
		connectionFactory.setAutoCommit(true);
		connectionFactory.create().as(StepVerifier::create)
				.consumeNextWith(emitted -> assertThat(emitted.isAutoCommit()).isTrue())
				.verifyComplete();
		connectionFactory.destroy();
	}

	@Test
	public void shouldSuppressClose() {
		// With close suppression, close() completes without closing the target,
		// and the connection stays usable afterwards.
		SingleConnectionFactory connectionFactory = new SingleConnectionFactory(
				"r2dbc:h2:mem:///foo", true);
		Connection connection = connectionFactory.create().block();
		StepVerifier.create(connection.close()).verifyComplete();
		assertThat(connection).isInstanceOf(Wrapped.class);
		assertThat(((Wrapped) connection).unwrap()).isInstanceOf(H2Connection.class);
		StepVerifier.create(
				connection.setTransactionIsolationLevel(IsolationLevel.READ_COMMITTED))
				.verifyComplete();
		connectionFactory.destroy();
	}

	@Test
	public void shouldNotSuppressClose() {
		// Without suppression, the connection is really closed and further
		// operations fail.
		SingleConnectionFactory connectionFactory = new SingleConnectionFactory(
				"r2dbc:h2:mem:///foo", false);
		Connection connection = connectionFactory.create().block();
		StepVerifier.create(connection.close()).verifyComplete();
		StepVerifier.create(connection.setTransactionIsolationLevel(
				IsolationLevel.READ_COMMITTED)).verifyError(
						R2dbcNonTransientResourceException.class);
		connectionFactory.destroy();
	}

	@Test
	public void releaseConnectionShouldNotCloseConnection() {
		// Releasing the factory-managed connection must not close it.
		Connection connectionMock = mock(Connection.class);
		ConnectionFactoryMetadata metadata = mock(ConnectionFactoryMetadata.class);
		SingleConnectionFactory connectionFactory = new SingleConnectionFactory(
				connectionMock, metadata, true);
		Connection connection = connectionFactory.create().block();
		ConnectionFactoryUtils.releaseConnection(connection, connectionFactory)
				.as(StepVerifier::create)
				.verifyComplete();
		verify(connectionMock, never()).close();
	}

	@Test
	public void releaseConnectionShouldCloseUnrelatedConnection() {
		// A connection that is not the factory's own is closed on release.
		Connection connectionMock = mock(Connection.class);
		Connection otherConnection = mock(Connection.class);
		ConnectionFactoryMetadata metadata = mock(ConnectionFactoryMetadata.class);
		when(otherConnection.close()).thenReturn(Mono.empty());
		SingleConnectionFactory connectionFactory = new SingleConnectionFactory(
				connectionMock, metadata, false);
		connectionFactory.create().as(StepVerifier::create).expectNextCount(1).verifyComplete();
		ConnectionFactoryUtils.releaseConnection(otherConnection, connectionFactory)
				.as(StepVerifier::create)
				.verifyComplete();
		verify(otherConnection).close();
	}
}

View File

@ -0,0 +1,155 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection;
import java.util.concurrent.atomic.AtomicReference;
import io.r2dbc.spi.Connection;
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.Wrapped;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import org.springframework.transaction.reactive.TransactionalOperator;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
import static org.mockito.BDDMockito.mock;
import static org.mockito.BDDMockito.times;
import static org.mockito.BDDMockito.verify;
import static org.mockito.BDDMockito.verifyNoInteractions;
import static org.mockito.BDDMockito.when;
/**
 * Unit tests for {@link TransactionAwareConnectionFactoryProxy}.
 *
 * <p>The proxy hands out connection proxies that participate in the current
 * reactive transaction and implement {@code Wrapped} for unwrapping.
 *
 * @author Mark Paluch
 * @author Christoph Strobl
 */
public class TransactionAwareConnectionFactoryProxyUnitTests {

	ConnectionFactory connectionFactoryMock = mock(ConnectionFactory.class);

	Connection connectionMock1 = mock(Connection.class);

	Connection connectionMock2 = mock(Connection.class);

	Connection connectionMock3 = mock(Connection.class);

	R2dbcTransactionManager tm;

	@BeforeEach
	public void before() {
		// Each create() call hands out the next mock connection in sequence;
		// the raw Mono casts bridge the Publisher-typed stub signature.
		when(connectionFactoryMock.create()).thenReturn((Mono) Mono.just(connectionMock1),
				(Mono) Mono.just(connectionMock2), (Mono) Mono.just(connectionMock3));
		tm = new R2dbcTransactionManager(connectionFactoryMock);
	}

	@Test
	public void createShouldWrapConnection() {
		// Emitted connections are proxies implementing Wrapped.
		new TransactionAwareConnectionFactoryProxy(connectionFactoryMock).create()
				.as(StepVerifier::create)
				.consumeNextWith(connection -> assertThat(connection).isInstanceOf(Wrapped.class))
				.verifyComplete();
	}

	@Test
	public void unwrapShouldReturnTargetConnection() {
		new TransactionAwareConnectionFactoryProxy(connectionFactoryMock).create()
				.map(Wrapped.class::cast).as(StepVerifier::create)
				.consumeNextWith(wrapped -> assertThat(wrapped.unwrap()).isEqualTo(connectionMock1))
				.verifyComplete();
	}

	@Test
	public void unwrapShouldReturnTargetConnectionEvenWhenClosed() {
		// unwrap() keeps working after close() has completed on the proxy.
		when(connectionMock1.close()).thenReturn(Mono.empty());
		new TransactionAwareConnectionFactoryProxy(connectionFactoryMock).create()
				.map(Connection.class::cast).flatMap(
						connection -> Mono.from(connection.close()).then(Mono.just(connection))).as(
								StepVerifier::create)
				.consumeNextWith(wrapped -> assertThat(((Wrapped<?>) wrapped).unwrap()).isEqualTo(connectionMock1))
				.verifyComplete();
	}

	@Test
	public void getTargetConnectionShouldReturnTargetConnection() {
		new TransactionAwareConnectionFactoryProxy(connectionFactoryMock).create()
				.map(Wrapped.class::cast).as(StepVerifier::create)
				.consumeNextWith(wrapped -> assertThat(wrapped.unwrap()).isEqualTo(connectionMock1))
				.verifyComplete();
	}

	@Test
	public void getMetadataShouldThrowsErrorEvenWhenClosed() {
		// Unlike unwrap(), getMetadata() is rejected once the proxy is closed.
		// NOTE(review): method name has a grammar slip ("ShouldThrows").
		when(connectionMock1.close()).thenReturn(Mono.empty());
		new TransactionAwareConnectionFactoryProxy(connectionFactoryMock).create()
				.map(Connection.class::cast).flatMap(
						connection -> Mono.from(connection.close())
								.then(Mono.just(connection))).as(StepVerifier::create)
				.consumeNextWith(connection -> assertThatIllegalStateException().isThrownBy(
						connection::getMetadata)).verifyComplete();
	}

	@Test
	public void hashCodeShouldReturnProxyHash() {
		// The proxy uses identity hashing rather than delegating to the target.
		new TransactionAwareConnectionFactoryProxy(connectionFactoryMock).create()
				.map(Connection.class::cast).as(StepVerifier::create)
				.consumeNextWith(connection -> assertThat(connection.hashCode()).isEqualTo(
						System.identityHashCode(connection))).verifyComplete();
	}

	@Test
	public void equalsShouldCompareCorrectly() {
		// equals(..) is identity-based: the proxy equals itself, not its target.
		new TransactionAwareConnectionFactoryProxy(connectionFactoryMock).create()
				.map(Connection.class::cast).as(StepVerifier::create)
				.consumeNextWith(connection -> {
					assertThat(connection.equals(connection)).isTrue();
					assertThat(connection.equals(connectionMock1)).isFalse();
				}).verifyComplete();
	}

	@Test
	public void shouldEmitBoundConnection() {
		// Within a transaction, the proxy emits the transaction-bound connection
		// and the factory is asked for exactly one physical connection.
		when(connectionMock1.beginTransaction()).thenReturn(Mono.empty());
		when(connectionMock1.commitTransaction()).thenReturn(Mono.empty());
		when(connectionMock1.close()).thenReturn(Mono.empty());
		TransactionalOperator rxtx = TransactionalOperator.create(tm);
		AtomicReference<Connection> transactionalConnection = new AtomicReference<>();
		TransactionAwareConnectionFactoryProxy proxyCf = new TransactionAwareConnectionFactoryProxy(
				connectionFactoryMock);
		ConnectionFactoryUtils.getConnection(connectionFactoryMock)
				.doOnNext(transactionalConnection::set).flatMap(connection -> proxyCf.create()
						.doOnNext(wrappedConnection -> assertThat(((Wrapped<?>) wrappedConnection).unwrap()).isSameAs(connection)))
				.as(rxtx::transactional)
				.flatMapMany(Connection::close)
				.as(StepVerifier::create)
				.verifyComplete();
		verifyNoInteractions(connectionMock2);
		verifyNoInteractions(connectionMock3);
		verify(connectionFactoryMock, times(1)).create();
	}
}

View File

@ -0,0 +1,136 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection.init;
import io.r2dbc.spi.ConnectionFactory;
import org.junit.jupiter.api.Test;
import reactor.test.StepVerifier;
import org.springframework.core.io.ClassRelativeResourceLoader;
import org.springframework.core.io.Resource;
import org.springframework.r2dbc.core.DatabaseClient;
/**
 * Abstract test support for {@link DatabasePopulator}.
 *
 * <p>Subclasses provide the {@link ConnectionFactory} against which the
 * configured SQL scripts are executed.
 *
 * @author Mark Paluch
 */
public abstract class AbstractDatabaseInitializationTests {

	// Resolves script resources relative to the concrete subclass.
	ClassRelativeResourceLoader resourceLoader = new ClassRelativeResourceLoader(
			getClass());

	ResourceDatabasePopulator databasePopulator = new ResourceDatabasePopulator();

	@Test
	public void scriptWithSingleLineCommentsAndFailedDrop() {
		databasePopulator.addScript(resource("db-schema-failed-drop-comments.sql"));
		databasePopulator.addScript(resource("db-test-data.sql"));
		// Tolerate DROP failures so the schema script runs on a fresh database.
		databasePopulator.setIgnoreFailedDrops(true);
		runPopulator();
		assertUsersDatabaseCreated("Heisenberg");
	}

	// Runs the configured populator against the subclass-provided factory
	// and awaits completion.
	private void runPopulator() {
		databasePopulator.populate(getConnectionFactory()) //
				.as(StepVerifier::create) //
				.verifyComplete();
	}

	@Test
	public void scriptWithStandardEscapedLiteral() {
		databasePopulator.addScript(defaultSchema());
		databasePopulator.addScript(resource("db-test-data-escaped-literal.sql"));
		runPopulator();
		// The doubled-quote escape must survive as a literal quote in the data.
		assertUsersDatabaseCreated("'Heisenberg'");
	}

	@Test
	public void scriptWithMySqlEscapedLiteral() {
		databasePopulator.addScript(defaultSchema());
		databasePopulator.addScript(resource("db-test-data-mysql-escaped-literal.sql"));
		runPopulator();
		assertUsersDatabaseCreated("\\$Heisenberg\\$");
	}

	@Test
	public void scriptWithMultipleStatements() {
		databasePopulator.addScript(defaultSchema());
		databasePopulator.addScript(resource("db-test-data-multiple.sql"));
		runPopulator();
		assertUsersDatabaseCreated("Heisenberg", "Jesse");
	}

	@Test
	public void scriptWithMultipleStatementsAndLongSeparator() {
		databasePopulator.addScript(defaultSchema());
		databasePopulator.addScript(resource("db-test-data-endings.sql"));
		// Scripts in db-test-data-endings.sql use "@@" as the statement separator.
		databasePopulator.setSeparator("@@");
		runPopulator();
		assertUsersDatabaseCreated("Heisenberg", "Jesse");
	}

	/**
	 * Return the {@link ConnectionFactory} to run the scripts against.
	 */
	abstract ConnectionFactory getConnectionFactory();

	// Resolves a script resource relative to the concrete test class.
	Resource resource(String path) {
		return resourceLoader.getResource(path);
	}

	Resource defaultSchema() {
		return resource("db-schema.sql");
	}

	Resource usersSchema() {
		return resource("users-schema.sql");
	}

	void assertUsersDatabaseCreated(String... lastNames) {
		assertUsersDatabaseCreated(getConnectionFactory(), lastNames);
	}

	// Verifies that exactly one user row exists per expected last name.
	void assertUsersDatabaseCreated(ConnectionFactory connectionFactory,
			String... lastNames) {
		DatabaseClient client = DatabaseClient.create(connectionFactory);
		for (String lastName : lastNames) {
			client.sql("select count(0) from users where last_name = :name") //
					.bind("name", lastName) //
					.map((row, metadata) -> row.get(0)) //
					.first() //
					.map(number -> ((Number) number).intValue()) //
					.as(StepVerifier::create) //
					.expectNext(1).as(
							"Did not find user with last name [" + lastName + "].") //
					.verifyComplete();
		}
	}
}

View File

@ -0,0 +1,113 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection.init;
import java.util.LinkedHashSet;
import java.util.Set;
import io.r2dbc.spi.Connection;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import static org.mockito.BDDMockito.mock;
import static org.mockito.BDDMockito.times;
import static org.mockito.BDDMockito.verify;
import static org.mockito.BDDMockito.when;
/**
* Unit tests for {@link CompositeDatabasePopulator}.
*
* @author Mark Paluch
*/
public class CompositeDatabasePopulatorTests {

	Connection mockedConnection = mock(Connection.class);

	DatabasePopulator mockedDatabasePopulator1 = mock(DatabasePopulator.class);

	DatabasePopulator mockedDatabasePopulator2 = mock(DatabasePopulator.class);

	@BeforeEach
	public void before() {
		when(mockedDatabasePopulator1.populate(mockedConnection)).thenReturn(Mono.empty());
		when(mockedDatabasePopulator2.populate(mockedConnection)).thenReturn(Mono.empty());
	}

	// Runs the composite against the mocked connection and awaits completion.
	private void runAndAwait(CompositeDatabasePopulator composite) {
		composite.populate(mockedConnection).as(StepVerifier::create).verifyComplete();
	}

	// Asserts that both delegate populators were invoked exactly once.
	private void verifyBothPopulatorsInvoked() {
		verify(mockedDatabasePopulator1, times(1)).populate(mockedConnection);
		verify(mockedDatabasePopulator2, times(1)).populate(mockedConnection);
	}

	@Test
	public void addPopulators() {
		CompositeDatabasePopulator composite = new CompositeDatabasePopulator();
		composite.addPopulators(mockedDatabasePopulator1, mockedDatabasePopulator2);
		runAndAwait(composite);
		verifyBothPopulatorsInvoked();
	}

	@Test
	public void setPopulatorsWithMultiple() {
		CompositeDatabasePopulator composite = new CompositeDatabasePopulator();
		// setPopulators(..) accepts several delegates in a single call.
		composite.setPopulators(mockedDatabasePopulator1, mockedDatabasePopulator2);
		runAndAwait(composite);
		verifyBothPopulatorsInvoked();
	}

	@Test
	public void setPopulatorsForOverride() {
		CompositeDatabasePopulator composite = new CompositeDatabasePopulator();
		composite.setPopulators(mockedDatabasePopulator1);
		// A second setPopulators(..) call replaces the earlier configuration.
		composite.setPopulators(mockedDatabasePopulator2);
		runAndAwait(composite);
		verify(mockedDatabasePopulator1, times(0)).populate(mockedConnection);
		verify(mockedDatabasePopulator2, times(1)).populate(mockedConnection);
	}

	@Test
	public void constructWithVarargs() {
		CompositeDatabasePopulator composite = new CompositeDatabasePopulator(
				mockedDatabasePopulator1, mockedDatabasePopulator2);
		runAndAwait(composite);
		verifyBothPopulatorsInvoked();
	}

	@Test
	public void constructWithCollection() {
		Set<DatabasePopulator> delegates = new LinkedHashSet<>();
		delegates.add(mockedDatabasePopulator1);
		delegates.add(mockedDatabasePopulator2);
		runAndAwait(new CompositeDatabasePopulator(delegates));
		verifyBothPopulatorsInvoked();
	}
}

View File

@ -0,0 +1,76 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection.init;
import java.util.concurrent.atomic.AtomicBoolean;
import io.r2dbc.spi.test.MockConnection;
import io.r2dbc.spi.test.MockConnectionFactory;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.mock;
import static org.mockito.BDDMockito.when;
/**
* Unit tests for {@link ConnectionFactoryInitializer}.
*
* @author Mark Paluch
*/
public class ConnectionFactoryInitializerUnitTests {

	AtomicBoolean called = new AtomicBoolean();

	DatabasePopulator populator = mock(DatabasePopulator.class);

	MockConnection connection = MockConnection.builder().build();

	MockConnectionFactory connectionFactory = MockConnectionFactory.builder()
			.connection(connection).build();

	// Stubs the populator to flip the flag as soon as its Mono is subscribed to.
	private void stubPopulator() {
		when(populator.populate(connectionFactory)).thenReturn(
				Mono.<Void> empty().doOnSubscribe(it -> called.set(true)));
	}

	@Test
	public void shouldInitializeConnectionFactory() {
		stubPopulator();
		ConnectionFactoryInitializer initializer = new ConnectionFactoryInitializer();
		initializer.setConnectionFactory(connectionFactory);
		initializer.setDatabasePopulator(populator);
		// afterPropertiesSet() subscribes to the populator.
		initializer.afterPropertiesSet();
		assertThat(called).isTrue();
	}

	@Test
	public void shouldCleanConnectionFactory() {
		stubPopulator();
		ConnectionFactoryInitializer initializer = new ConnectionFactoryInitializer();
		initializer.setConnectionFactory(connectionFactory);
		initializer.setDatabaseCleaner(populator);
		initializer.afterPropertiesSet();
		// The cleaner runs on destroy().
		initializer.destroy();
		assertThat(called).isTrue();
	}
}

View File

@ -0,0 +1,61 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection.init;
import java.util.UUID;
import io.r2dbc.spi.ConnectionFactories;
import io.r2dbc.spi.ConnectionFactory;
import org.junit.jupiter.api.Test;
import reactor.test.StepVerifier;
/**
 * Integration tests for {@link DatabasePopulator} using H2.
 *
 * @author Mark Paluch
 */
public class H2DatabasePopulatorIntegrationTests
		extends AbstractDatabaseInitializationTests {

	// Random database name isolates each test class run.
	UUID databaseName = UUID.randomUUID();

	// DB_CLOSE_DELAY=-1 keeps the in-memory database alive between connections.
	ConnectionFactory connectionFactory = ConnectionFactories.get("r2dbc:h2:mem:///"
			+ databaseName + "?options=DB_CLOSE_DELAY=-1;DB_CLOSE_ON_EXIT=FALSE");

	@Override
	ConnectionFactory getConnectionFactory() {
		return this.connectionFactory;
	}

	@Test
	public void shouldRunScript() {
		databasePopulator.addScript(usersSchema());
		databasePopulator.addScript(resource("db-test-data-h2.sql"));
		// Set statement separator to double newline so that ";" is not
		// considered a statement separator within the source code of the
		// aliased function 'REVERSE'.
		databasePopulator.setSeparator("\n\n");
		databasePopulator.populate(connectionFactory).as(
				StepVerifier::create).verifyComplete();
		assertUsersDatabaseCreated(connectionFactory, "White");
	}
}

View File

@ -0,0 +1,117 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection.init;
import org.junit.jupiter.api.Test;
import org.springframework.core.io.Resource;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.BDDMockito.mock;
/**
* Unit tests for {@link ResourceDatabasePopulator}.
*
* @author Mark Paluch
*/
public class ResourceDatabasePopulatorUnitTests {

	private static final Resource script1 = mock(Resource.class);

	private static final Resource script2 = mock(Resource.class);

	private static final Resource script3 = mock(Resource.class);

	@Test
	public void constructWithNullResource() {
		assertThatIllegalArgumentException()
				.isThrownBy(() -> new ResourceDatabasePopulator((Resource) null));
	}

	@Test
	public void constructWithNullResourceArray() {
		assertThatIllegalArgumentException()
				.isThrownBy(() -> new ResourceDatabasePopulator((Resource[]) null));
	}

	@Test
	public void constructWithResource() {
		ResourceDatabasePopulator populator = new ResourceDatabasePopulator(script1);
		assertThat(populator.scripts).hasSize(1);
	}

	@Test
	public void constructWithMultipleResources() {
		ResourceDatabasePopulator populator = new ResourceDatabasePopulator(script1, script2);
		assertThat(populator.scripts).hasSize(2);
	}

	@Test
	public void constructWithMultipleResourcesAndThenAddScript() {
		ResourceDatabasePopulator populator = new ResourceDatabasePopulator(script1, script2);
		assertThat(populator.scripts).hasSize(2);
		// addScript(..) appends to the scripts supplied at construction time.
		populator.addScript(script3);
		assertThat(populator.scripts).hasSize(3);
	}

	@Test
	public void addScriptsWithNullResource() {
		ResourceDatabasePopulator populator = new ResourceDatabasePopulator();
		assertThatIllegalArgumentException()
				.isThrownBy(() -> populator.addScripts((Resource) null));
	}

	@Test
	public void addScriptsWithNullResourceArray() {
		ResourceDatabasePopulator populator = new ResourceDatabasePopulator();
		assertThatIllegalArgumentException()
				.isThrownBy(() -> populator.addScripts((Resource[]) null));
	}

	@Test
	public void setScriptsWithNullResource() {
		ResourceDatabasePopulator populator = new ResourceDatabasePopulator();
		assertThatIllegalArgumentException()
				.isThrownBy(() -> populator.setScripts((Resource) null));
	}

	@Test
	public void setScriptsWithNullResourceArray() {
		ResourceDatabasePopulator populator = new ResourceDatabasePopulator();
		assertThatIllegalArgumentException()
				.isThrownBy(() -> populator.setScripts((Resource[]) null));
	}

	@Test
	public void setScriptsAndThenAddScript() {
		ResourceDatabasePopulator populator = new ResourceDatabasePopulator();
		assertThat(populator.scripts).isEmpty();
		// setScripts(..) replaces the configuration; addScript(..) appends afterwards.
		populator.setScripts(script1, script2);
		assertThat(populator.scripts).hasSize(2);
		populator.addScript(script3);
		assertThat(populator.scripts).hasSize(3);
	}
}

View File

@ -0,0 +1,219 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection.init;
import java.util.ArrayList;
import java.util.List;
import org.assertj.core.util.Strings;
import org.junit.jupiter.api.Test;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.core.io.support.EncodedResource;
import static org.assertj.core.api.Assertions.assertThat;
/**
* Unit tests for {@link ScriptUtils}.
*
* @author Thomas Risberg
* @author Sam Brannen
* @author Phillip Webb
* @author Chris Baldwin
* @author Nicolas Debeissat
* @author Mark Paluch
*/
public class ScriptUtilsUnitTests {
@Test
public void splitSqlScriptDelimitedWithSemicolon() {
	String rawStatement1 = "insert into customer (id, name)\nvalues (1, 'Rod ; Johnson'), (2, 'Adrian \n Collier')";
	String cleanedStatement1 = "insert into customer (id, name) values (1, 'Rod ; Johnson'), (2, 'Adrian \n Collier')";
	String rawStatement2 = "insert into orders(id, order_date, customer_id)\nvalues (1, '2008-01-02', 2)";
	String cleanedStatement2 = "insert into orders(id, order_date, customer_id) values (1, '2008-01-02', 2)";
	String rawStatement3 = "insert into orders(id, order_date, customer_id) values (1, '2008-01-02', 2)";
	String cleanedStatement3 = "insert into orders(id, order_date, customer_id) values (1, '2008-01-02', 2)";
	// Newlines outside quotes collapse to spaces; a quoted ';' is not a separator.
	String script = String.join(";", rawStatement1, rawStatement2, rawStatement3);
	List<String> statements = new ArrayList<>();
	ScriptUtils.splitSqlScript(script, ";", statements);
	assertThat(statements).hasSize(3)
			.containsSequence(cleanedStatement1, cleanedStatement2, cleanedStatement3);
}
@Test
public void splitSqlScriptDelimitedWithNewLine() {
	String statement1 = "insert into customer (id, name) values (1, 'Rod ; Johnson'), (2, 'Adrian \n Collier')";
	String statement2 = "insert into orders(id, order_date, customer_id) values (1, '2008-01-02', 2)";
	String statement3 = "insert into orders(id, order_date, customer_id) values (1, '2008-01-02', 2)";
	// With "\n" as the separator, statements are returned verbatim.
	String script = String.join("\n", statement1, statement2, statement3);
	List<String> statements = new ArrayList<>();
	ScriptUtils.splitSqlScript(script, "\n", statements);
	assertThat(statements).hasSize(3)
			.containsSequence(statement1, statement2, statement3);
}
@Test
public void splitSqlScriptDelimitedWithNewLineButDefaultDelimiterSpecified() {
	String statement1 = "do something";
	String statement2 = "do something else";
	char delim = '\n';
	String script = statement1 + delim + statement2 + delim;
	List<String> statements = new ArrayList<>();
	// With the default separator (";"), the newline-delimited input is treated
	// as one single statement whose newlines are replaced by spaces.
	ScriptUtils.splitSqlScript(script, ScriptUtils.DEFAULT_STATEMENT_SEPARATOR,
			statements);
	assertThat(statements).hasSize(1).contains(script.replace('\n', ' '));
}
@Test
public void splitScriptWithSingleQuotesNestedInsideDoubleQuotes() {
	String statement1 = "select '1' as \"Dogbert's owner's\" from dual";
	String statement2 = "select '2' as \"Dilbert's\" from dual";
	// Apostrophes inside double-quoted identifiers must not confuse the splitter.
	String script = statement1 + ';' + statement2 + ';';
	List<String> statements = new ArrayList<>();
	ScriptUtils.splitSqlScript(script, ';', statements);
	assertThat(statements).hasSize(2).containsSequence(statement1, statement2);
}
@Test
public void readAndSplitScriptWithMultipleNewlinesAsSeparator() {
	// The script on the classpath uses a blank line ("\n\n") between statements.
	String script = readScript("db-test-data-multi-newline.sql");
	List<String> actualStatements = new ArrayList<>();
	ScriptUtils.splitSqlScript(script, "\n\n", actualStatements);

	String expectedFirst = "insert into users (last_name) values ('Walter')";
	String expectedSecond = "insert into users (last_name) values ('Jesse')";
	assertThat(actualStatements.size()).as("wrong number of statements").isEqualTo(2);
	assertThat(actualStatements.get(0)).as("statement 1 not split correctly").isEqualTo(expectedFirst);
	assertThat(actualStatements.get(1)).as("statement 2 not split correctly").isEqualTo(expectedSecond);
}
@Test
public void readAndSplitScriptContainingComments() {
	// Unix line endings; the Windows variant of this test converts them to CRLF.
	splitScriptContainingComments(readScript("test-data-with-comments.sql"));
}
@Test
public void readAndSplitScriptContainingCommentsWithWindowsLineEnding() {
	// Rewrite Unix line endings to Windows-style CRLF before splitting.
	String windowsScript = readScript("test-data-with-comments.sql").replace("\n", "\r\n");
	splitScriptContainingComments(windowsScript);
}
/**
 * Split the given script on {@code ';'} and assert the four expected statements
 * survive with comment lines stripped. Shared by the Unix and Windows
 * line-ending variants of the comment-handling tests.
 */
private void splitScriptContainingComments(String script) {
	List<String> actualStatements = new ArrayList<>();
	ScriptUtils.splitSqlScript(script, ';', actualStatements);

	assertThat(actualStatements).hasSize(4).containsSequence(
			"insert into customer (id, name) values (1, 'Rod; Johnson'), (2, 'Adrian Collier')",
			"insert into orders(id, order_date, customer_id) values (1, '2008-01-02', 2)",
			"insert into orders(id, order_date, customer_id) values (1, '2008-01-02', 2)",
			"INSERT INTO persons( person_id , name) VALUES( 1 , 'Name' )");
}
@Test
public void readAndSplitScriptContainingCommentsWithLeadingTabs() {
	String script = readScript("test-data-with-comments-and-leading-tabs.sql");
	List<String> actualStatements = new ArrayList<>();
	ScriptUtils.splitSqlScript(script, ';', actualStatements);

	// Leading tabs before statements and comments must be discarded.
	assertThat(actualStatements).hasSize(3).containsSequence(
			"insert into customer (id, name) values (1, 'Walter White')",
			"insert into orders(id, order_date, customer_id) values (1, '2013-06-08', 1)",
			"insert into orders(id, order_date, customer_id) values (2, '2013-06-08', 1)");
}
@Test
public void readAndSplitScriptContainingMultiLineComments() {
	String script = readScript("test-data-with-multi-line-comments.sql");
	List<String> actualStatements = new ArrayList<>();
	ScriptUtils.splitSqlScript(script, ';', actualStatements);

	// Block comments (/* ... */) spanning several lines must be removed entirely.
	assertThat(actualStatements).hasSize(2).containsSequence(
			"INSERT INTO users(first_name, last_name) VALUES('Walter', 'White')",
			"INSERT INTO users(first_name, last_name) VALUES( 'Jesse' , 'Pinkman' )");
}
@Test
public void readAndSplitScriptContainingMultiLineNestedComments() {
	String script = readScript("test-data-with-multi-line-nested-comments.sql");
	List<String> actualStatements = new ArrayList<>();
	ScriptUtils.splitSqlScript(script, ';', actualStatements);

	// Nested block comments must also be stripped without breaking the split.
	assertThat(actualStatements).hasSize(2).containsSequence(
			"INSERT INTO users(first_name, last_name) VALUES('Walter', 'White')",
			"INSERT INTO users(first_name, last_name) VALUES( 'Jesse' , 'Pinkman' )");
}
@Test
public void containsDelimiters() {
	// A ';' appearing only inside a string literal is not a statement delimiter.
	assertThat(ScriptUtils.containsSqlScriptDelimiters("select 1\n select ';'", ";")).isFalse();
	assertThat(ScriptUtils.containsSqlScriptDelimiters("select 1; select 2", ";")).isTrue();

	// Likewise for a newline delimiter that occurs only inside a literal.
	assertThat(ScriptUtils.containsSqlScriptDelimiters("select 1; select '\\n\n';", "\n")).isFalse();
	assertThat(ScriptUtils.containsSqlScriptDelimiters("select 1\n select 2", "\n")).isTrue();

	// A multi-character delimiter must be matched in full.
	assertThat(ScriptUtils.containsSqlScriptDelimiters("select 1\n select 2", "\n\n")).isFalse();
	assertThat(ScriptUtils.containsSqlScriptDelimiters("select 1\n\n select 2", "\n\n")).isTrue();

	// MySQL style escapes '\\'
	assertThat(ScriptUtils.containsSqlScriptDelimiters(
			"insert into users(first_name, last_name)\nvalues('a\\\\', 'b;')", ";")).isFalse();
	assertThat(ScriptUtils.containsSqlScriptDelimiters(
			"insert into users(first_name, last_name)\nvalues('Charles', 'd\\'Artagnan'); select 1;", ";")).isTrue();
}
/**
 * Read the given classpath script (resolved relative to this test class) into a
 * single String, blocking on the reactive read.
 */
private String readScript(String path) {
	ClassPathResource scriptResource = new ClassPathResource(path, getClass());
	return ScriptUtils.readScript(new EncodedResource(scriptResource),
			new DefaultDataBufferFactory()).block();
}
}

View File

@ -0,0 +1,195 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection.lookup;
import io.r2dbc.spi.ConnectionFactory;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import reactor.util.context.Context;
import static java.util.Collections.singletonMap;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
/**
 * Unit tests for {@link AbstractRoutingConnectionFactory}.
 *
 * <p>The factory under test resolves its routing key from the Reactor
 * subscriber context under {@link #ROUTING_KEY}.
 *
 * @author Mark Paluch
 * @author Jens Schauder
 */
@ExtendWith(MockitoExtension.class)
public class AbstractRoutingConnectionFactoryUnitTests {

	private static final String ROUTING_KEY = "routingKey";

	@Mock
	ConnectionFactory defaultConnectionFactory;

	@Mock
	ConnectionFactory routedConnectionFactory;

	// Factory under test; recreated before each test with the default target set.
	DummyRoutingConnectionFactory connectionFactory;


	@BeforeEach
	public void before() {
		connectionFactory = new DummyRoutingConnectionFactory();
		connectionFactory.setDefaultTargetConnectionFactory(defaultConnectionFactory);
	}

	@Test
	public void shouldDetermineRoutedFactory() {
		// With a matching routing key in the subscriber context, the mapped factory is chosen.
		connectionFactory.setTargetConnectionFactories(
				singletonMap("key", routedConnectionFactory));
		connectionFactory.setConnectionFactoryLookup(new MapConnectionFactoryLookup());
		connectionFactory.afterPropertiesSet();

		connectionFactory.determineTargetConnectionFactory()
				.subscriberContext(Context.of(ROUTING_KEY, "key"))
				.as(StepVerifier::create)
				.expectNext(routedConnectionFactory)
				.verifyComplete();
	}

	@Test
	public void shouldFallbackToDefaultConnectionFactory() {
		// No routing key in the context: the default target factory is used.
		connectionFactory.setTargetConnectionFactories(
				singletonMap("key", routedConnectionFactory));
		connectionFactory.afterPropertiesSet();

		connectionFactory.determineTargetConnectionFactory()
				.as(StepVerifier::create)
				.expectNext(defaultConnectionFactory)
				.verifyComplete();
	}

	@Test
	public void initializationShouldFailUnsupportedLookupKey() {
		// A target value that is neither a ConnectionFactory nor a lookup name is rejected.
		connectionFactory.setTargetConnectionFactories(singletonMap("key", new Object()));

		assertThatThrownBy(() -> connectionFactory.afterPropertiesSet()).isInstanceOf(
				IllegalArgumentException.class);
	}

	@Test
	public void initializationShouldFailUnresolvableKey() {
		// A String target value is treated as a lookup name; an unknown name fails fast.
		connectionFactory.setTargetConnectionFactories(singletonMap("key", "value"));
		connectionFactory.setConnectionFactoryLookup(new MapConnectionFactoryLookup());

		assertThatThrownBy(() -> connectionFactory.afterPropertiesSet())
				.isInstanceOf(ConnectionFactoryLookupFailureException.class)
				.hasMessageContaining(
						"No ConnectionFactory with name 'value' registered");
	}

	@Test
	public void unresolvableConnectionFactoryRetrievalShouldFail() {
		// With lenient fallback disabled, an unknown routing key is an error instead
		// of falling back to the default factory.
		connectionFactory.setLenientFallback(false);
		connectionFactory.setConnectionFactoryLookup(new MapConnectionFactoryLookup());
		connectionFactory.setTargetConnectionFactories(
				singletonMap("key", routedConnectionFactory));
		connectionFactory.afterPropertiesSet();

		connectionFactory.determineTargetConnectionFactory()
				.subscriberContext(Context.of(ROUTING_KEY, "unknown"))
				.as(StepVerifier::create)
				.verifyError(IllegalStateException.class);
	}

	@Test
	public void connectionFactoryRetrievalWithUnknownLookupKeyShouldReturnDefaultConnectionFactory() {
		// Default (lenient) behavior: an unknown routing key resolves to the default factory.
		connectionFactory.setTargetConnectionFactories(
				singletonMap("key", routedConnectionFactory));
		connectionFactory.setDefaultTargetConnectionFactory(defaultConnectionFactory);
		connectionFactory.afterPropertiesSet();

		connectionFactory.determineTargetConnectionFactory()
				.subscriberContext(Context.of(ROUTING_KEY, "unknown"))
				.as(StepVerifier::create)
				.expectNext(defaultConnectionFactory)
				.verifyComplete();
	}

	@Test
	public void connectionFactoryRetrievalWithoutLookupKeyShouldReturnDefaultConnectionFactory() {
		// An absent routing key resolves to the default factory even with lenient
		// fallback disabled.
		connectionFactory.setTargetConnectionFactories(
				singletonMap("key", routedConnectionFactory));
		connectionFactory.setDefaultTargetConnectionFactory(defaultConnectionFactory);
		connectionFactory.setLenientFallback(false);
		connectionFactory.afterPropertiesSet();

		connectionFactory.determineTargetConnectionFactory()
				.as(StepVerifier::create)
				.expectNext(defaultConnectionFactory)
				.verifyComplete();
	}

	@Test
	public void shouldLookupFromMap() {
		// String target values are resolved through the configured lookup:
		// routing key "my-key" -> lookup name "lookup-key" -> routedConnectionFactory.
		MapConnectionFactoryLookup lookup = new MapConnectionFactoryLookup("lookup-key",
				routedConnectionFactory);
		connectionFactory.setConnectionFactoryLookup(lookup);
		connectionFactory.setTargetConnectionFactories(
				singletonMap("my-key", "lookup-key"));
		connectionFactory.afterPropertiesSet();

		connectionFactory.determineTargetConnectionFactory()
				.subscriberContext(Context.of(ROUTING_KEY, "my-key"))
				.as(StepVerifier::create)
				.expectNext(routedConnectionFactory)
				.verifyComplete();
	}

	@Test
	public void shouldAllowModificationsAfterInitialization() {
		MapConnectionFactoryLookup lookup = new MapConnectionFactoryLookup();
		connectionFactory.setConnectionFactoryLookup(lookup);
		connectionFactory.setTargetConnectionFactories(lookup.getConnectionFactories());
		connectionFactory.afterPropertiesSet();

		// Before "lookup-key" is registered, routing resolves to the default factory.
		connectionFactory.determineTargetConnectionFactory()
				.subscriberContext(Context.of(ROUTING_KEY, "lookup-key"))
				.as(StepVerifier::create)
				.expectNext(defaultConnectionFactory)
				.verifyComplete();

		// After registering the key and re-initializing, the routed factory is resolved.
		lookup.addConnectionFactory("lookup-key", routedConnectionFactory);
		connectionFactory.afterPropertiesSet();

		connectionFactory.determineTargetConnectionFactory()
				.subscriberContext(Context.of(ROUTING_KEY, "lookup-key"))
				.as(StepVerifier::create)
				.expectNext(routedConnectionFactory)
				.verifyComplete();
	}


	// Test fixture: resolves the current lookup key from the Reactor subscriber context.
	static class DummyRoutingConnectionFactory extends AbstractRoutingConnectionFactory {

		@Override
		protected Mono<Object> determineCurrentLookupKey() {
			return Mono.subscriberContext().filter(context -> context.hasKey(ROUTING_KEY))
					.map(context -> context.get(ROUTING_KEY));
		}
	}

}

View File

@ -0,0 +1,88 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection.lookup;
import io.r2dbc.spi.ConnectionFactory;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanNotOfRequiredTypeException;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.BDDMockito.mock;
import static org.mockito.BDDMockito.when;
/**
 * Unit tests for {@link BeanFactoryConnectionFactoryLookup}.
 *
 * @author Mark Paluch
 */
@ExtendWith(MockitoExtension.class)
public class BeanFactoryConnectionFactoryLookupUnitTests {

	private static final String CONNECTION_FACTORY_BEAN_NAME = "connectionFactory";

	@Mock
	BeanFactory beanFactory;


	@Test
	public void shouldLookupConnectionFactory() {
		DummyConnectionFactory registeredFactory = new DummyConnectionFactory();
		when(beanFactory.getBean(CONNECTION_FACTORY_BEAN_NAME, ConnectionFactory.class))
				.thenReturn(registeredFactory);

		BeanFactoryConnectionFactoryLookup lookup = new BeanFactoryConnectionFactoryLookup();
		lookup.setBeanFactory(beanFactory);

		// The lookup must hand back the very instance registered in the BeanFactory.
		ConnectionFactory resolved = lookup.getConnectionFactory(CONNECTION_FACTORY_BEAN_NAME);
		assertThat(resolved).isNotNull();
		assertThat(resolved).isSameAs(registeredFactory);
	}

	@Test
	public void shouldLookupWhereBeanFactoryYieldsNonConnectionFactoryType() {
		BeanFactory localBeanFactory = mock(BeanFactory.class);
		when(localBeanFactory.getBean(CONNECTION_FACTORY_BEAN_NAME, ConnectionFactory.class))
				.thenThrow(new BeanNotOfRequiredTypeException(CONNECTION_FACTORY_BEAN_NAME,
						ConnectionFactory.class, String.class));

		BeanFactoryConnectionFactoryLookup lookup =
				new BeanFactoryConnectionFactoryLookup(localBeanFactory);

		// A bean-type mismatch surfaces as a lookup failure.
		assertThatExceptionOfType(ConnectionFactoryLookupFailureException.class)
				.isThrownBy(() -> lookup.getConnectionFactory(CONNECTION_FACTORY_BEAN_NAME));
	}

	@Test
	public void shouldLookupWhereBeanFactoryHasNotBeenSupplied() {
		// Without a configured BeanFactory there is nothing to delegate to.
		BeanFactoryConnectionFactoryLookup lookup = new BeanFactoryConnectionFactoryLookup();

		assertThatThrownBy(() -> lookup.getConnectionFactory(CONNECTION_FACTORY_BEAN_NAME))
				.isInstanceOf(IllegalStateException.class);
	}

}

View File

@ -0,0 +1,43 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection.lookup;
import io.r2dbc.spi.Connection;
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.ConnectionFactoryMetadata;
import org.reactivestreams.Publisher;
/**
 * Stub, do-nothing {@link ConnectionFactory} implementation.
 *
 * <p>Both methods throw {@link UnsupportedOperationException}; tests use this
 * type only as a distinct {@link ConnectionFactory} instance and never invoke it.
 *
 * @author Mark Paluch
 */
class DummyConnectionFactory implements ConnectionFactory {

	@Override
	public ConnectionFactoryMetadata getMetadata() {
		throw new UnsupportedOperationException();
	}

	@Override
	public Publisher<? extends Connection> create() {
		throw new UnsupportedOperationException();
	}

}

View File

@ -0,0 +1,102 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.connection.lookup;
import java.util.HashMap;
import java.util.Map;
import io.r2dbc.spi.ConnectionFactory;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
/**
 * Unit tests for {@link MapConnectionFactoryLookup}.
 *
 * @author Mark Paluch
 */
public class MapConnectionFactoryLookupUnitTests {

	private static final String CONNECTION_FACTORY_NAME = "connectionFactory";


	@Test
	public void getConnectionFactorysReturnsUnmodifiableMap() {
		MapConnectionFactoryLookup lookup = new MapConnectionFactoryLookup();
		Map<String, ConnectionFactory> exposed = lookup.getConnectionFactories();

		// The exposed map is a read-only view; mutation attempts must fail.
		assertThatThrownBy(() -> exposed.put("", new DummyConnectionFactory()))
				.isInstanceOf(UnsupportedOperationException.class);
	}

	@Test
	public void shouldLookupConnectionFactory() {
		DummyConnectionFactory registeredFactory = new DummyConnectionFactory();
		Map<String, ConnectionFactory> factories = new HashMap<>();
		factories.put(CONNECTION_FACTORY_NAME, registeredFactory);

		MapConnectionFactoryLookup lookup = new MapConnectionFactoryLookup();
		lookup.setConnectionFactories(factories);

		ConnectionFactory resolved = lookup.getConnectionFactory(CONNECTION_FACTORY_NAME);
		assertThat(resolved).isNotNull().isSameAs(registeredFactory);
	}

	@Test
	public void addingConnectionFactoryPermitsOverride() {
		DummyConnectionFactory originalFactory = new DummyConnectionFactory();
		DummyConnectionFactory replacementFactory = new DummyConnectionFactory();

		Map<String, ConnectionFactory> factories = new HashMap<>();
		factories.put(CONNECTION_FACTORY_NAME, originalFactory);

		MapConnectionFactoryLookup lookup = new MapConnectionFactoryLookup();
		lookup.setConnectionFactories(factories);
		// A later registration under the same name replaces the earlier one.
		lookup.addConnectionFactory(CONNECTION_FACTORY_NAME, replacementFactory);

		ConnectionFactory resolved = lookup.getConnectionFactory(CONNECTION_FACTORY_NAME);
		assertThat(resolved).isNotNull().isSameAs(replacementFactory);
	}

	@Test
	@SuppressWarnings({ "unchecked", "rawtypes" })
	public void getConnectionFactoryWhereSuppliedMapHasNonConnectionFactoryTypeUnderSpecifiedKey() {
		// Raw map smuggles a non-ConnectionFactory value under the key.
		Map factories = new HashMap<>();
		factories.put(CONNECTION_FACTORY_NAME, new Object());

		MapConnectionFactoryLookup lookup = new MapConnectionFactoryLookup(factories);

		assertThatThrownBy(() -> lookup.getConnectionFactory(CONNECTION_FACTORY_NAME))
				.isInstanceOf(ClassCastException.class);
	}

	@Test
	public void getConnectionFactoryWhereSuppliedMapHasNoEntryForSpecifiedKey() {
		MapConnectionFactoryLookup lookup = new MapConnectionFactoryLookup();

		// An unknown name is reported as a lookup failure, not a null return.
		assertThatThrownBy(() -> lookup.getConnectionFactory(CONNECTION_FACTORY_NAME))
				.isInstanceOf(ConnectionFactoryLookupFailureException.class);
	}

}

View File

@ -0,0 +1,152 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core;
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.Result;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import org.springframework.dao.DataIntegrityViolationException;
import static org.assertj.core.api.Assertions.assertThat;
/**
 * Integration tests for {@link DatabaseClient}.
 *
 * <p>Subclasses supply a database-specific {@link ConnectionFactory} and the
 * matching {@code CREATE TABLE} statement for the {@code legoset} table.
 *
 * @author Mark Paluch
 * @author Mingyuan Wu
 */
public abstract class AbstractDatabaseClientIntegrationTests {

	private ConnectionFactory connectionFactory;


	@BeforeEach
	public void before() {
		connectionFactory = createConnectionFactory();

		// Recreate the legoset table before each test; errors from DROP TABLE
		// (e.g. table does not exist yet) are swallowed via onErrorResume.
		Mono.from(connectionFactory.create())
				.flatMapMany(connection -> Flux.from(connection.createStatement("DROP TABLE legoset").execute())
						.flatMap(Result::getRowsUpdated)
						.onErrorResume(e -> Mono.empty())
						.thenMany(connection.createStatement(getCreateTableStatement()).execute())
						.flatMap(Result::getRowsUpdated).thenMany(connection.close())).as(StepVerifier::create)
				.verifyComplete();
	}

	/**
	 * Creates a {@link ConnectionFactory} to be used in this test.
	 *
	 * @return the {@link ConnectionFactory} to be used in this test
	 */
	protected abstract ConnectionFactory createConnectionFactory();

	/**
	 * Return the CREATE TABLE statement for table {@code legoset} with the following
	 * three columns:
	 * <ul>
	 * <li>id integer (primary key), not null</li>
	 * <li>name varchar(255), nullable</li>
	 * <li>manual integer, nullable</li>
	 * </ul>
	 *
	 * @return the CREATE TABLE statement for table {@code legoset} with three columns.
	 */
	protected abstract String getCreateTableStatement();

	@Test
	public void executeInsert() {
		DatabaseClient databaseClient = DatabaseClient.create(connectionFactory);

		// Insert one row with a null manual, then read it back by id.
		databaseClient.sql("INSERT INTO legoset (id, name, manual) VALUES(:id, :name, :manual)")
				.bind("id", 42055)
				.bind("name", "SCHAUFELRADBAGGER")
				.bindNull("manual", Integer.class)
				.fetch().rowsUpdated()
				.as(StepVerifier::create)
				.expectNext(1)
				.verifyComplete();

		databaseClient.sql("SELECT id FROM legoset")
				.map(row -> row.get("id"))
				.first()
				.as(StepVerifier::create)
				.assertNext(actual -> {
					// Drivers may return different numeric types; compare via intValue().
					assertThat(actual).isInstanceOf(Number.class);
					assertThat(((Number) actual).intValue()).isEqualTo(42055);
				}).verifyComplete();
	}

	@Test
	public void shouldTranslateDuplicateKeyException() {
		DatabaseClient databaseClient = DatabaseClient.create(connectionFactory);

		// First insert succeeds; repeating it violates the primary key and the
		// driver exception is expected to be translated to Spring's
		// DataIntegrityViolationException including the offending SQL.
		executeInsert();

		databaseClient.sql(
				"INSERT INTO legoset (id, name, manual) VALUES(:id, :name, :manual)")
				.bind("id", 42055)
				.bind("name", "SCHAUFELRADBAGGER")
				.bindNull("manual", Integer.class)
				.fetch().rowsUpdated()
				.as(StepVerifier::create)
				.expectErrorSatisfies(exception -> assertThat(exception)
						.isInstanceOf(DataIntegrityViolationException.class)
						.hasMessageContaining("execute; SQL [INSERT INTO legoset"))
				.verify();
	}

	@Test
	public void executeDeferred() {
		DatabaseClient databaseClient = DatabaseClient.create(connectionFactory);

		// SQL supplied lazily through a Supplier instead of a literal String.
		databaseClient.sql(() -> "INSERT INTO legoset (id, name, manual) VALUES(:id, :name, :manual)")
				.bind("id", 42055)
				.bind("name", "SCHAUFELRADBAGGER")
				.bindNull("manual", Integer.class)
				.fetch().rowsUpdated()
				.as(StepVerifier::create)
				.expectNext(1)
				.verifyComplete();

		databaseClient.sql("SELECT id FROM legoset")
				.map(row -> row.get("id")).first()
				.as(StepVerifier::create)
				.expectNextCount(1)
				.verifyComplete();
	}

	@Test
	public void shouldEmitGeneratedKey() {
		DatabaseClient databaseClient = DatabaseClient.create(connectionFactory);

		// Insert without an id; request the generated "id" column back via the
		// statement filter.
		databaseClient.sql(
				"INSERT INTO legoset ( name, manual) VALUES(:name, :manual)")
				.bind("name","SCHAUFELRADBAGGER")
				.bindNull("manual", Integer.class)
				.filter(statement -> statement.returnGeneratedValues("id"))
				.map(row -> (Number) row.get("id"))
				.first()
				.as(StepVerifier::create)
				.expectNextCount(1)
				.verifyComplete();
	}

}

View File

@ -0,0 +1,208 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core;
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.Result;
import org.assertj.core.api.Condition;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.r2dbc.connection.R2dbcTransactionManager;
import org.springframework.transaction.ReactiveTransactionManager;
import org.springframework.transaction.reactive.TransactionalOperator;
import org.springframework.transaction.support.DefaultTransactionDefinition;
import static org.assertj.core.api.Assertions.assertThat;
/**
 * Abstract base class for transactional integration tests for {@link DatabaseClient}.
 *
 * <p>Subclasses supply a database-specific {@link ConnectionFactory} and the
 * matching {@code CREATE TABLE} statement for the {@code legoset} table.
 *
 * @author Mark Paluch
 * @author Christoph Strobl
 */
public abstract class AbstractTransactionalDatabaseClientIntegrationTests {

	private ConnectionFactory connectionFactory;

	AnnotationConfigApplicationContext context;

	DatabaseClient databaseClient;

	R2dbcTransactionManager transactionManager;

	// Transactional operator bound to transactionManager, shared by the tests.
	TransactionalOperator rxtx;


	@BeforeEach
	public void before() {
		connectionFactory = createConnectionFactory();

		context = new AnnotationConfigApplicationContext();
		// Make the ConnectionFactory injectable into Config without registering it as a bean.
		context.getBeanFactory().registerResolvableDependency(ConnectionFactory.class, connectionFactory);
		context.register(Config.class);
		context.refresh();

		// Recreate the legoset table before each test; errors from DROP TABLE
		// (e.g. table does not exist yet) are swallowed via onErrorResume.
		Mono.from(connectionFactory.create())
				.flatMapMany(connection -> Flux.from(connection.createStatement("DROP TABLE legoset").execute())
						.flatMap(Result::getRowsUpdated)
						.onErrorResume(e -> Mono.empty())
						.thenMany(connection.createStatement(getCreateTableStatement()).execute())
						.flatMap(Result::getRowsUpdated).thenMany(connection.close())).as(StepVerifier::create).verifyComplete();

		databaseClient = DatabaseClient.create(connectionFactory);
		transactionManager = new R2dbcTransactionManager(connectionFactory);
		rxtx = TransactionalOperator.create(transactionManager);
	}

	@AfterEach
	public void tearDown() {
		context.close();
	}

	/**
	 * Create a {@link ConnectionFactory} to be used in this test.
	 * @return the {@link ConnectionFactory} to be used in this test.
	 */
	protected abstract ConnectionFactory createConnectionFactory();

	/**
	 * Return the CREATE TABLE statement for table {@code legoset} with the following three columns:
	 * <ul>
	 * <li>id integer (primary key), not null</li>
	 * <li>name varchar(255), nullable</li>
	 * <li>manual integer, nullable</li>
	 * </ul>
	 *
	 * @return the CREATE TABLE statement for table {@code legoset} with three columns.
	 */
	protected abstract String getCreateTableStatement();

	/**
	 * Get a parameterized {@code INSERT INTO legoset} statement setting id, name, and manual values.
	 * @return the parameterized INSERT statement.
	 */
	protected String getInsertIntoLegosetStatement() {
		return "INSERT INTO legoset (id, name, manual) VALUES(:id, :name, :manual)";
	}

	@Test
	public void executeInsertInTransaction() {
		// Insert inside a transaction, then verify the row is visible afterwards.
		Flux<Integer> integerFlux = databaseClient
				.sql(getInsertIntoLegosetStatement())
				.bind(0, 42055)
				.bind(1, "SCHAUFELRADBAGGER")
				.bindNull(2, Integer.class)
				.fetch().rowsUpdated().flux().as(rxtx::transactional);

		integerFlux.as(StepVerifier::create)
				.expectNext(1)
				.verifyComplete();

		databaseClient
				.sql("SELECT id FROM legoset")
				.fetch()
				.first()
				.as(StepVerifier::create)
				.assertNext(actual -> assertThat(actual).hasEntrySatisfying("id", numberOf(42055)))
				.verifyComplete();
	}

	@Test
	public void shouldRollbackTransaction() {
		// The error after the insert must roll the transaction back, so the
		// subsequent SELECT completes without emitting any row.
		Mono<Object> integerFlux = databaseClient.sql(getInsertIntoLegosetStatement())
				.bind(0, 42055)
				.bind(1, "SCHAUFELRADBAGGER")
				.bindNull(2, Integer.class)
				.fetch().rowsUpdated()
				.then(Mono.error(new IllegalStateException("failed")))
				.as(rxtx::transactional);

		integerFlux.as(StepVerifier::create)
				.expectError(IllegalStateException.class)
				.verify();

		databaseClient
				.sql("SELECT id FROM legoset")
				.fetch()
				.first()
				.as(StepVerifier::create)
				.verifyComplete();
	}

	@Test
	public void shouldRollbackTransactionUsingTransactionalOperator() {
		// Same rollback expectation, but with a locally created DatabaseClient and
		// TransactionalOperator (explicit DefaultTransactionDefinition).
		DatabaseClient databaseClient = DatabaseClient.create(connectionFactory);

		TransactionalOperator transactionalOperator = TransactionalOperator
				.create(new R2dbcTransactionManager(connectionFactory), new DefaultTransactionDefinition());

		Flux<Integer> integerFlux = databaseClient.sql(getInsertIntoLegosetStatement())
				.bind(0, 42055)
				.bind(1, "SCHAUFELRADBAGGER")
				.bindNull(2, Integer.class)
				.fetch().rowsUpdated()
				.thenMany(Mono.fromSupplier(() -> {
					throw new IllegalStateException("failed");
				}));

		integerFlux.as(transactionalOperator::transactional)
				.as(StepVerifier::create)
				.expectError(IllegalStateException.class)
				.verify();

		databaseClient
				.sql("SELECT id FROM legoset")
				.fetch()
				.first()
				.as(StepVerifier::create)
				.verifyComplete();
	}

	// AssertJ condition matching any Number whose intValue() equals the expected value.
	private Condition<? super Object> numberOf(int expected) {
		return new Condition<>(object -> object instanceof Number &&
				((Number) object).intValue() == expected, "Number %d", expected);
	}


	@Configuration(proxyBeanMethods = false)
	static class Config {

		@Autowired GenericApplicationContext context;

		@Bean
		ReactiveTransactionManager txMgr(ConnectionFactory connectionFactory) {
			return new R2dbcTransactionManager(connectionFactory);
		}

		@Bean
		TransactionalOperator transactionalOperator(ReactiveTransactionManager transactionManager) {
			return TransactionalOperator.create(transactionManager);
		}
	}

}

View File

@ -0,0 +1,435 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core;
import java.util.Arrays;
import io.r2dbc.spi.Connection;
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.Result;
import io.r2dbc.spi.Statement;
import io.r2dbc.spi.test.MockColumnMetadata;
import io.r2dbc.spi.test.MockResult;
import io.r2dbc.spi.test.MockRow;
import io.r2dbc.spi.test.MockRowMetadata;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InOrder;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscription;
import reactor.core.CoreSubscriber;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import org.springframework.dao.IncorrectResultSizeDataAccessException;
import org.springframework.lang.Nullable;
import org.springframework.r2dbc.core.binding.BindMarkersFactory;
import org.springframework.r2dbc.core.binding.BindTarget;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.doReturn;
import static org.mockito.BDDMockito.inOrder;
import static org.mockito.BDDMockito.mock;
import static org.mockito.BDDMockito.times;
import static org.mockito.BDDMockito.verify;
import static org.mockito.BDDMockito.verifyNoInteractions;
import static org.mockito.BDDMockito.verifyNoMoreInteractions;
import static org.mockito.BDDMockito.when;
/**
 * Unit tests for {@link DefaultDatabaseClient}.
 *
 * <p>All tests run against a mocked {@link ConnectionFactory} that always hands
 * out the single mocked {@link Connection}, so statement creation, parameter
 * binding, and connection release can be verified via Mockito interactions.
 *
 * @author Mark Paluch
 * @author Ferdinand Jacobs
 * @author Jens Schauder
 */
@ExtendWith(MockitoExtension.class)
@MockitoSettings(strictness = Strictness.LENIENT)
public class DefaultDatabaseClientUnitTests {

	@Mock
	Connection connection;

	private DatabaseClient.Builder databaseClientBuilder;


	@BeforeEach
	public void before() {
		// Shared builder: the factory emits the mocked connection, close() completes
		// empty, and bind markers are indexed starting at 1 ($1, $2, ...).
		ConnectionFactory connectionFactory = mock(ConnectionFactory.class);
		when(connectionFactory.create()).thenReturn((Publisher) Mono.just(connection));
		when(connection.close()).thenReturn(Mono.empty());
		databaseClientBuilder = DatabaseClient.builder().connectionFactory(
				connectionFactory).bindMarkers(BindMarkersFactory.indexed("$", 1));
	}

	@Test
	public void shouldCloseConnectionOnlyOnce() {
		DefaultDatabaseClient databaseClient = (DefaultDatabaseClient) databaseClientBuilder.build();
		Flux<Object> flux = databaseClient.inConnectionMany(connection -> Flux.empty());
		// Subscribe with a subscriber that cancels its subscription AFTER
		// onComplete: cancellation following completion must not release
		// (close) the connection a second time.
		flux.subscribe(new CoreSubscriber<Object>() {
			Subscription subscription;

			@Override
			public void onSubscribe(Subscription s) {
				s.request(1);
				subscription = s;
			}

			@Override
			public void onNext(Object o) {
			}

			@Override
			public void onError(Throwable t) {
			}

			@Override
			public void onComplete() {
				// Cancel after completion to exercise the double-release path.
				subscription.cancel();
			}
		});
		verify(connection, times(1)).close();
	}

	@Test
	public void executeShouldBindNullValues() {
		// With named parameters disabled, index- and name-based null bindings
		// are passed through to the Statement verbatim.
		Statement statement = mockStatementFor("SELECT * FROM table WHERE key = $1");
		DatabaseClient databaseClient = databaseClientBuilder.namedParameters(false).build();
		databaseClient.sql("SELECT * FROM table WHERE key = $1").bindNull(0,
				String.class).then().as(StepVerifier::create).verifyComplete();
		verify(statement).bindNull(0, String.class);
		databaseClient.sql("SELECT * FROM table WHERE key = $1").bindNull("$1",
				String.class).then().as(StepVerifier::create).verifyComplete();
		verify(statement).bindNull("$1", String.class);
	}

	@Test
	public void executeShouldBindSettableValues() {
		// An empty Parameter carries only type information and must translate
		// into a bindNull call on the Statement.
		Statement statement = mockStatementFor("SELECT * FROM table WHERE key = $1");
		DatabaseClient databaseClient = databaseClientBuilder.namedParameters(false).build();
		databaseClient.sql("SELECT * FROM table WHERE key = $1").bind(0,
				Parameter.empty(String.class)).then().as(
						StepVerifier::create).verifyComplete();
		verify(statement).bindNull(0, String.class);
		databaseClient.sql("SELECT * FROM table WHERE key = $1").bind("$1",
				Parameter.empty(String.class)).then().as(
						StepVerifier::create).verifyComplete();
		verify(statement).bindNull("$1", String.class);
	}

	@Test
	public void executeShouldBindNamedNullValues() {
		// The named parameter ":key" is expanded to the indexed marker "$1",
		// so the null binding arrives at the Statement by index 0.
		Statement statement = mockStatementFor("SELECT * FROM table WHERE key = $1");
		DatabaseClient databaseClient = databaseClientBuilder.build();
		databaseClient.sql("SELECT * FROM table WHERE key = :key").bindNull("key",
				String.class).then().as(StepVerifier::create).verifyComplete();
		verify(statement).bindNull(0, String.class);
	}

	@Test
	public void executeShouldBindNamedValuesFromIndexes() {
		// A collection bound to a single named parameter expands into one
		// marker (and one bind call) per element.
		Statement statement = mockStatementFor(
				"SELECT id, name, manual FROM legoset WHERE name IN ($1, $2, $3)");
		DatabaseClient databaseClient = databaseClientBuilder.build();
		databaseClient.sql(
				"SELECT id, name, manual FROM legoset WHERE name IN (:name)").bind(0,
						Arrays.asList("unknown", "dunno", "other")).then().as(
								StepVerifier::create).verifyComplete();
		verify(statement).bind(0, "unknown");
		verify(statement).bind(1, "dunno");
		verify(statement).bind(2, "other");
		verify(statement).execute();
		verifyNoMoreInteractions(statement);
	}

	@Test
	public void executeShouldBindValues() {
		// Plain values and Parameter-wrapped values must bind identically,
		// both by index and by name.
		Statement statement = mockStatementFor("SELECT * FROM table WHERE key = $1");
		DatabaseClient databaseClient = databaseClientBuilder.build();
		databaseClient.sql("SELECT * FROM table WHERE key = $1").bind(0,
				Parameter.from("foo")).then().as(StepVerifier::create).verifyComplete();
		verify(statement).bind(0, "foo");
		databaseClient.sql("SELECT * FROM table WHERE key = $1").bind("$1",
				"foo").then().as(StepVerifier::create).verifyComplete();
		verify(statement).bind("$1", "foo");
	}

	@Test
	public void executeShouldBindNamedValuesByIndex() {
		// Binding by the parameter's name (":key") resolves to its positional
		// index in the expanded statement.
		Statement statement = mockStatementFor("SELECT * FROM table WHERE key = $1");
		DatabaseClient databaseClient = databaseClientBuilder.build();
		databaseClient.sql("SELECT * FROM table WHERE key = :key").bind("key",
				"foo").then().as(StepVerifier::create).verifyComplete();
		verify(statement).bind(0, "foo");
	}

	@Test
	public void rowsUpdatedShouldEmitSingleValue() {
		// Regardless of whether the driver reports zero, one, or multiple
		// update counts, rowsUpdated() must emit exactly one value.
		Result result = mock(Result.class);
		when(result.getRowsUpdated()).thenReturn(Mono.empty(), Mono.just(2),
				Flux.just(1, 2, 3));
		mockStatementFor("DROP TABLE tab;", result);
		DatabaseClient databaseClient = databaseClientBuilder.build();
		databaseClient.sql("DROP TABLE tab;").fetch().rowsUpdated().as(
				StepVerifier::create).expectNextCount(1).verifyComplete();
		databaseClient.sql("DROP TABLE tab;").fetch().rowsUpdated().as(
				StepVerifier::create).expectNextCount(1).verifyComplete();
		databaseClient.sql("DROP TABLE tab;").fetch().rowsUpdated().as(
				StepVerifier::create).expectNextCount(1).verifyComplete();
	}

	@Test
	public void selectShouldEmitFirstValue() {
		// Two rows are available; first() must emit only the first one.
		MockRowMetadata metadata = MockRowMetadata.builder().columnMetadata(
				MockColumnMetadata.builder().name("name").build()).build();
		MockResult.Builder resultBuilder = MockResult.builder().rowMetadata(metadata);
		MockResult result = resultBuilder.row(MockRow.builder().identified(0, Object.class, "Walter").build())
				.row(MockRow.builder().identified(0, Object.class, "White").build()).build();
		mockStatementFor("SELECT * FROM person", result);
		DatabaseClient databaseClient = databaseClientBuilder.build();
		databaseClient.sql("SELECT * FROM person").map(row -> row.get(0))
				.first()
				.as(StepVerifier::create)
				.expectNext("Walter")
				.verifyComplete();
	}

	@Test
	public void selectShouldEmitAllValues() {
		// all() must emit every available row in order.
		MockRowMetadata metadata = MockRowMetadata.builder().columnMetadata(
				MockColumnMetadata.builder().name("name").build()).build();
		MockResult.Builder resultBuilder = MockResult.builder().rowMetadata(metadata);
		MockResult result = resultBuilder.row(MockRow.builder().identified(0, Object.class, "Walter").build())
				.row(MockRow.builder().identified(0, Object.class, "White").build()).build();
		mockStatementFor("SELECT * FROM person", result);
		DatabaseClient databaseClient = databaseClientBuilder.build();
		databaseClient.sql("SELECT * FROM person").map(row -> row.get(0))
				.all()
				.as(StepVerifier::create)
				.expectNext("Walter")
				.expectNext("White")
				.verifyComplete();
	}

	@Test
	public void selectOneShouldFailWithException() {
		// one() must fail when more than a single row is returned.
		MockRowMetadata metadata = MockRowMetadata.builder().columnMetadata(
				MockColumnMetadata.builder().name("name").build()).build();
		MockResult.Builder resultBuilder = MockResult.builder().rowMetadata(metadata);
		MockResult result = resultBuilder.row(MockRow.builder().identified(0, Object.class, "Walter").build())
				.row(MockRow.builder().identified(0, Object.class, "White").build()).build();
		mockStatementFor("SELECT * FROM person", result);
		DatabaseClient databaseClient = databaseClientBuilder.build();
		databaseClient.sql("SELECT * FROM person").map(row -> row.get(0))
				.one()
				.as(StepVerifier::create)
				.verifyError(IncorrectResultSizeDataAccessException.class);
	}

	@Test
	public void shouldApplyExecuteFunction() {
		// A custom executeFunction replaces Statement.execute() entirely, so
		// the mocked Statement must never be touched.
		Statement statement = mockStatement();
		MockResult result = mockSingleColumnResult(
				MockRow.builder().identified(0, Object.class, "Walter"));
		DatabaseClient databaseClient = databaseClientBuilder.executeFunction(
				stmnt -> Mono.just(result)).build();
		databaseClient.sql("SELECT").fetch().all().as(
				StepVerifier::create).expectNextCount(1).verifyComplete();
		verifyNoInteractions(statement);
	}

	@Test
	public void shouldApplyPreparedOperation() {
		// A PreparedOperation supplies both the query text and the bindings.
		MockResult result = mockSingleColumnResult(
				MockRow.builder().identified(0, Object.class, "Walter"));
		Statement statement = mockStatementFor("SELECT * FROM person", result);
		DatabaseClient databaseClient = databaseClientBuilder.build();
		databaseClient.sql(new PreparedOperation<String>() {

			@Override
			public String toQuery() {
				return "SELECT * FROM person";
			}

			@Override
			public String getSource() {
				return "SELECT";
			}

			@Override
			public void bindTo(BindTarget target) {
				target.bind("index", "value");
			}
		}).fetch().all().as(
				StepVerifier::create).expectNextCount(1).verifyComplete();
		verify(statement).bind("index", "value");
	}

	@Test
	public void shouldApplyStatementFilterFunctions() {
		// Two-arg filter functions must run in registration order and both be
		// applied before the statement is executed.
		MockRowMetadata metadata = MockRowMetadata.builder().columnMetadata(
				MockColumnMetadata.builder().name("name").build()).build();
		MockResult result = MockResult.builder().rowMetadata(metadata).build();
		Statement statement = mockStatement(result);
		DatabaseClient databaseClient = databaseClientBuilder.build();
		databaseClient.sql("SELECT").filter(
				(s, next) -> next.execute(s.returnGeneratedValues("foo"))).filter(
						(s, next) -> next.execute(
								s.returnGeneratedValues("bar"))).fetch().all().as(
										StepVerifier::create).verifyComplete();
		InOrder inOrder = inOrder(statement);
		inOrder.verify(statement).returnGeneratedValues("foo");
		inOrder.verify(statement).returnGeneratedValues("bar");
		inOrder.verify(statement).execute();
		inOrder.verifyNoMoreInteractions();
	}

	@Test
	public void shouldApplySimpleStatementFilterFunctions() {
		// Single-arg (statement-only) filter functions get the same in-order
		// treatment as the two-arg variant.
		MockResult result = mockSingleColumnEmptyResult();
		Statement statement = mockStatement(result);
		DatabaseClient databaseClient = databaseClientBuilder.build();
		databaseClient.sql("SELECT").filter(
				s -> s.returnGeneratedValues("foo")).filter(
						s -> s.returnGeneratedValues("bar")).fetch().all().as(
								StepVerifier::create).verifyComplete();
		InOrder inOrder = inOrder(statement);
		inOrder.verify(statement).returnGeneratedValues("foo");
		inOrder.verify(statement).returnGeneratedValues("bar");
		inOrder.verify(statement).execute();
		inOrder.verifyNoMoreInteractions();
	}

	// Mocks a Statement matching any SQL, producing no result.
	private Statement mockStatement() {
		return mockStatementFor(null, null);
	}

	// Mocks a Statement matching any SQL, producing the given result.
	private Statement mockStatement(Result result) {
		return mockStatementFor(null, result);
	}

	// Mocks a Statement for the given SQL, producing no result.
	private Statement mockStatementFor(String sql) {
		return mockStatementFor(sql, null);
	}

	/**
	 * Mocks a {@link Statement} created by {@code connection} for the given SQL
	 * ({@code null} matches any SQL) that emits {@code result} on execution
	 * ({@code null} executes to an empty publisher).
	 */
	private Statement mockStatementFor(@Nullable String sql, @Nullable Result result) {
		Statement statement = mock(Statement.class);
		when(connection.createStatement(sql == null ? anyString() : eq(sql))).thenReturn(
				statement);
		when(statement.returnGeneratedValues(anyString())).thenReturn(statement);
		when(statement.returnGeneratedValues()).thenReturn(statement);
		doReturn(result == null ? Mono.empty() : Flux.just(result)).when(
				statement).execute();
		return statement;
	}

	// Mocks a single-column Result without any rows.
	private MockResult mockSingleColumnEmptyResult() {
		return mockSingleColumnResult(null);
	}

	/**
	 * Mocks a {@link Result} with a single column "name" and a single row if a non null
	 * row is provided.
	 */
	private MockResult mockSingleColumnResult(@Nullable MockRow.Builder row) {
		MockRowMetadata metadata = MockRowMetadata.builder().columnMetadata(
				MockColumnMetadata.builder().name("name").build()).build();
		MockResult.Builder resultBuilder = MockResult.builder().rowMetadata(metadata);
		if (row != null) {
			resultBuilder = resultBuilder.row(row.build());
		}
		return resultBuilder.build();
	}

}

View File

@ -0,0 +1,47 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core;
import io.r2dbc.h2.H2ConnectionFactory;
import io.r2dbc.spi.ConnectionFactory;
/**
 * Integration tests for {@link DatabaseClient} against H2.
 *
 * @author Mark Paluch
 */
public class H2DatabaseClientIntegrationTests
		extends AbstractDatabaseClientIntegrationTests {

	// DDL used by the base class to prepare the test schema; declared final so
	// the shared constant cannot be reassigned by accident (was a mutable
	// public static field).
	public static final String CREATE_TABLE_LEGOSET = "CREATE TABLE legoset (\n" //
			+ "    id          serial CONSTRAINT id PRIMARY KEY,\n" //
			+ "    version     integer NULL,\n" //
			+ "    name        varchar(255) NOT NULL,\n" //
			+ "    manual      integer NULL\n" //
			+ ");";

	@Override
	protected ConnectionFactory createConnectionFactory() {
		// Fresh in-memory H2 database per factory; named for traceability.
		return H2ConnectionFactory.inMemory("r2dbc-test");
	}

	@Override
	protected String getCreateTableStatement() {
		return CREATE_TABLE_LEGOSET;
	}

}

View File

@ -0,0 +1,46 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core;
import io.r2dbc.h2.H2ConnectionFactory;
import io.r2dbc.spi.ConnectionFactory;
/**
 * Integration tests for {@link DatabaseClient} against H2.
 *
 * @author Mark Paluch
 */
public class H2TransactionalDatabaseClientIntegrationTests
		extends AbstractTransactionalDatabaseClientIntegrationTests {

	// DDL used by the base class to prepare the test schema; declared final so
	// the shared constant cannot be reassigned by accident (was a mutable
	// public static field).
	public static final String CREATE_TABLE_LEGOSET = "CREATE TABLE legoset (\n" //
			+ "    id          integer CONSTRAINT id PRIMARY KEY,\n" //
			+ "    version     integer NULL,\n" //
			+ "    name        varchar(255) NOT NULL,\n" //
			+ "    manual      integer NULL\n" //
			+ ");";

	@Override
	protected ConnectionFactory createConnectionFactory() {
		// Separate in-memory database from the non-transactional tests.
		return H2ConnectionFactory.inMemory("r2dbc-transactional");
	}

	@Override
	protected String getCreateTableStatement() {
		return CREATE_TABLE_LEGOSET;
	}

}

View File

@ -0,0 +1,462 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import org.junit.jupiter.api.Test;
import org.springframework.r2dbc.core.binding.BindMarkersFactory;
import org.springframework.r2dbc.core.binding.BindTarget;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
/**
 * Unit tests for {@link NamedParameterUtils}.
 *
 * <p>Covers named-parameter parsing (comments, quotes, escapes, PostgreSQL
 * operators and casts) and substitution/binding against different
 * {@link BindMarkersFactory bind marker styles}.
 *
 * @author Mark Paluch
 * @author Jens Schauder
 */
public class NamedParameterUtilsUnitTests {

	// Default marker style for these tests: indexed markers starting at $1.
	private final BindMarkersFactory BIND_MARKERS = BindMarkersFactory.indexed("$", 1);


	@Test
	public void shouldParseSql() {
		String sql = "xxx :a yyyy :b :c :a zzzzz";
		ParsedSql psql = NamedParameterUtils.parseSqlStatement(sql);
		// ":a" occurs twice: 4 total references but only 3 distinct names.
		assertThat(psql.getParameterNames()).containsExactly("a", "b", "c", "a");
		assertThat(psql.getTotalParameterCount()).isEqualTo(4);
		assertThat(psql.getNamedParameterCount()).isEqualTo(3);
		String sql2 = "xxx &a yyyy ? zzzzz";
		ParsedSql psql2 = NamedParameterUtils.parseSqlStatement(sql2);
		assertThat(psql2.getParameterNames()).containsExactly("a");
		assertThat(psql2.getTotalParameterCount()).isEqualTo(1);
		assertThat(psql2.getNamedParameterCount()).isEqualTo(1);
		// Non-ASCII identifiers are valid parameter names.
		String sql3 = "xxx &ä+:ö" + '\t' + ":ü%10 yyyy ? zzzzz";
		ParsedSql psql3 = NamedParameterUtils.parseSqlStatement(sql3);
		assertThat(psql3.getParameterNames()).containsExactly("ä", "ö", "ü");
	}

	@Test
	public void substituteNamedParameters() {
		MapBindParameterSource namedParams = new MapBindParameterSource(new HashMap<>());
		namedParams.addValue("a", "a").addValue("b", "b").addValue("c", "c");
		PreparedOperation<?> operation = NamedParameterUtils.substituteNamedParameters(
				"xxx :a :b :c", BIND_MARKERS, namedParams);
		assertThat(operation.toQuery()).isEqualTo("xxx $1 $2 $3");
		// Named marker style appends the (hinted) parameter name to the marker.
		PreparedOperation<?> operation2 = NamedParameterUtils.substituteNamedParameters(
				"xxx :a :b :c", BindMarkersFactory.named("@", "P", 8), namedParams);
		assertThat(operation2.toQuery()).isEqualTo("xxx @P0a @P1b @P2c");
	}

	@Test
	public void substituteObjectArray() {
		// A collection of Object[] expands into one parenthesized tuple per array.
		MapBindParameterSource namedParams = new MapBindParameterSource(new HashMap<>());
		namedParams.addValue("a", Arrays.asList(new Object[] { "Walter", "Heisenberg" },
				new Object[] { "Walt Jr.", "Flynn" }));
		PreparedOperation<?> operation = NamedParameterUtils.substituteNamedParameters(
				"xxx :a", BIND_MARKERS, namedParams);
		assertThat(operation.toQuery()).isEqualTo("xxx ($1, $2), ($3, $4)");
	}

	@Test
	public void shouldBindObjectArray() {
		// Each element of each Object[] binds to its own consecutive index.
		MapBindParameterSource namedParams = new MapBindParameterSource(new HashMap<>());
		namedParams.addValue("a", Arrays.asList(new Object[] { "Walter", "Heisenberg" },
				new Object[] { "Walt Jr.", "Flynn" }));
		BindTarget bindTarget = mock(BindTarget.class);
		PreparedOperation<?> operation = NamedParameterUtils.substituteNamedParameters(
				"xxx :a", BIND_MARKERS, namedParams);
		operation.bindTo(bindTarget);
		verify(bindTarget).bind(0, "Walter");
		verify(bindTarget).bind(1, "Heisenberg");
		verify(bindTarget).bind(2, "Walt Jr.");
		verify(bindTarget).bind(3, "Flynn");
	}

	@Test
	public void parseSqlContainingComments() {
		// Parameters inside block and line comments must NOT be substituted.
		String sql1 = "/*+ HINT */ xxx /* comment ? */ :a yyyy :b :c :a zzzzz -- :xx XX\n";
		ParsedSql psql1 = NamedParameterUtils.parseSqlStatement(sql1);
		assertThat(expand(psql1)).isEqualTo(
				"/*+ HINT */ xxx /* comment ? */ $1 yyyy $2 $3 $1 zzzzz -- :xx XX\n");
		// Same statement without a trailing newline after the line comment.
		String sql2 = "/*+ HINT */ xxx /* comment ? */ :a yyyy :b :c :a zzzzz -- :xx XX";
		ParsedSql psql2 = NamedParameterUtils.parseSqlStatement(sql2);
		assertThat(expand(psql2)).isEqualTo(
				"/*+ HINT */ xxx /* comment ? */ $1 yyyy $2 $3 $1 zzzzz -- :xx XX");
	}

	@Test
	public void parseSqlStatementWithPostgresCasting() {
		// "::" is a Postgres cast, not a parameter prefix.
		String expectedSql = "select 'first name' from artists where id = $1 and birth_date=$2::timestamp";
		String sql = "select 'first name' from artists where id = :id and birth_date=:birthDate::timestamp";
		ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement(sql);
		PreparedOperation<?> operation = NamedParameterUtils.substituteNamedParameters(
				parsedSql, BIND_MARKERS, new MapBindParameterSource());
		assertThat(operation.toQuery()).isEqualTo(expectedSql);
	}

	@Test
	public void parseSqlStatementWithPostgresContainedOperator() {
		// "??" is the escaped Postgres containment operator, not a parameter.
		String expectedSql = "select 'first name' from artists where info->'stat'->'albums' = ?? $1 and '[\"1\",\"2\",\"3\"]'::jsonb ?? '4'";
		String sql = "select 'first name' from artists where info->'stat'->'albums' = ?? :album and '[\"1\",\"2\",\"3\"]'::jsonb ?? '4'";
		ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement(sql);
		assertThat(parsedSql.getTotalParameterCount()).isEqualTo(1);
		assertThat(expand(parsedSql)).isEqualTo(expectedSql);
	}

	@Test
	public void parseSqlStatementWithPostgresAnyArrayStringsExistsOperator() {
		// "?|" is a Postgres operator; the statement has no parameters at all.
		String expectedSql = "select '[\"3\", \"11\"]'::jsonb ?| '{1,3,11,12,17}'::text[]";
		String sql = "select '[\"3\", \"11\"]'::jsonb ?| '{1,3,11,12,17}'::text[]";
		ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement(sql);
		assertThat(parsedSql.getTotalParameterCount()).isEqualTo(0);
		assertThat(expand(parsedSql)).isEqualTo(expectedSql);
	}

	@Test
	public void parseSqlStatementWithPostgresAllArrayStringsExistsOperator() {
		// "?&" is a Postgres operator; only ":album" is a parameter.
		String expectedSql = "select '[\"3\", \"11\"]'::jsonb ?& '{1,3,11,12,17}'::text[] AND $1 = 'Back in Black'";
		String sql = "select '[\"3\", \"11\"]'::jsonb ?& '{1,3,11,12,17}'::text[] AND :album = 'Back in Black'";
		ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement(sql);
		assertThat(parsedSql.getTotalParameterCount()).isEqualTo(1);
		assertThat(expand(parsedSql)).isEqualTo(expectedSql);
	}

	@Test
	public void parseSqlStatementWithEscapedColon() {
		// "\:" escapes a colon; the escape character is removed in the output.
		String expectedSql = "select '0\\:0' as a, foo from bar where baz < DATE($1 23:59:59) and baz = $2";
		String sql = "select '0\\:0' as a, foo from bar where baz < DATE(:p1 23\\:59\\:59) and baz = :p2";
		ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement(sql);
		assertThat(parsedSql.getParameterNames()).containsExactly("p1", "p2");
		assertThat(expand(parsedSql)).isEqualTo(expectedSql);
	}

	@Test
	public void parseSqlStatementWithBracketDelimitedParameterNames() {
		// ":{name}" delimits a parameter embedded inside a word.
		String expectedSql = "select foo from bar where baz = b$1$2z";
		String sql = "select foo from bar where baz = b:{p1}:{p2}z";
		ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement(sql);
		assertThat(parsedSql.getParameterNames()).containsExactly("p1", "p2");
		assertThat(expand(parsedSql)).isEqualTo(expectedSql);
	}

	@Test
	public void parseSqlStatementWithEmptyBracketsOrBracketsInQuotes() {
		// Empty brackets and brackets inside string literals are no parameters.
		String expectedSql = "select foo from bar where baz = b:{}z";
		String sql = "select foo from bar where baz = b:{}z";
		ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement(sql);
		assertThat(parsedSql.getParameterNames()).isEmpty();
		assertThat(expand(parsedSql)).isEqualTo(expectedSql);
		String expectedSql2 = "select foo from bar where baz = 'b:{p1}z'";
		String sql2 = "select foo from bar where baz = 'b:{p1}z'";
		ParsedSql parsedSql2 = NamedParameterUtils.parseSqlStatement(sql2);
		assertThat(parsedSql2.getParameterNames()).isEmpty();
		assertThat(expand(parsedSql2)).isEqualTo(expectedSql2);
	}

	@Test
	public void parseSqlStatementWithSingleLetterInBrackets() {
		String expectedSql = "select foo from bar where baz = b$1z";
		String sql = "select foo from bar where baz = b:{p}z";
		ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement(sql);
		assertThat(parsedSql.getParameterNames()).containsExactly("p");
		assertThat(expand(parsedSql)).isEqualTo(expectedSql);
	}

	@Test
	public void parseSqlStatementWithLogicalAnd() {
		String expectedSql = "xxx & yyyy";
		ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement(expectedSql);
		assertThat(expand(parsedSql)).isEqualTo(expectedSql);
	}

	@Test
	public void substituteNamedParametersWithLogicalAnd() {
		String expectedSql = "xxx & yyyy";
		assertThat(expand(expectedSql)).isEqualTo(expectedSql);
	}

	@Test
	public void variableAssignmentOperator() {
		// ":=" (PL/pgSQL assignment) must not be treated as a parameter.
		String expectedSql = "x := 1";
		assertThat(expand(expectedSql)).isEqualTo(expectedSql);
	}

	@Test
	public void parseSqlStatementWithQuotedSingleQuote() {
		// Parameters inside quoted literals (with escaped quotes) are ignored.
		String sql = "SELECT ':foo'':doo', :xxx FROM DUAL";
		ParsedSql psql = NamedParameterUtils.parseSqlStatement(sql);
		assertThat(psql.getTotalParameterCount()).isEqualTo(1);
		assertThat(psql.getParameterNames()).containsExactly("xxx");
	}

	@Test
	public void parseSqlStatementWithQuotesAndCommentBefore() {
		String sql = "SELECT /*:doo*/':foo', :xxx FROM DUAL";
		ParsedSql psql = NamedParameterUtils.parseSqlStatement(sql);
		assertThat(psql.getTotalParameterCount()).isEqualTo(1);
		assertThat(psql.getParameterNames()).containsExactly("xxx");
	}

	@Test
	public void parseSqlStatementWithQuotesAndCommentAfter() {
		String sql2 = "SELECT ':foo'/*:doo*/, :xxx FROM DUAL";
		ParsedSql psql2 = NamedParameterUtils.parseSqlStatement(sql2);
		assertThat(psql2.getTotalParameterCount()).isEqualTo(1);
		assertThat(psql2.getParameterNames()).containsExactly("xxx");
	}

	@Test
	public void shouldAllowParsingMultipleUseOfParameter() {
		String sql = "SELECT * FROM person where name = :id or lastname = :id";
		ParsedSql parsed = NamedParameterUtils.parseSqlStatement(sql);
		assertThat(parsed.getTotalParameterCount()).isEqualTo(2);
		assertThat(parsed.getNamedParameterCount()).isEqualTo(1);
		assertThat(parsed.getParameterNames()).containsExactly("id", "id");
	}

	@Test
	public void multipleEqualParameterReferencesBindsValueOnce() {
		// With indexed markers, repeated references reuse the same marker and
		// the value is bound exactly once.
		String sql = "SELECT * FROM person where name = :id or lastname = :id";
		BindMarkersFactory factory = BindMarkersFactory.indexed("$", 0);
		PreparedOperation<String> operation = NamedParameterUtils.substituteNamedParameters(
				sql, factory, new MapBindParameterSource(
						Collections.singletonMap("id", Parameter.from("foo"))));
		assertThat(operation.toQuery()).isEqualTo(
				"SELECT * FROM person where name = $0 or lastname = $0");
		operation.bindTo(new BindTarget() {

			@Override
			public void bind(String identifier, Object value) {
				throw new UnsupportedOperationException();
			}

			@Override
			public void bind(int index, Object value) {
				assertThat(index).isEqualTo(0);
				assertThat(value).isEqualTo("foo");
			}

			@Override
			public void bindNull(String identifier, Class<?> type) {
				throw new UnsupportedOperationException();
			}

			@Override
			public void bindNull(int index, Class<?> type) {
				throw new UnsupportedOperationException();
			}
		});
	}

	@Test
	public void multipleEqualCollectionParameterReferencesBindsValueOnce() {
		// A repeated collection parameter reuses its expanded markers, and each
		// element is bound exactly once to its index.
		String sql = "SELECT * FROM person where name IN (:ids) or lastname IN (:ids)";
		BindMarkersFactory factory = BindMarkersFactory.indexed("$", 0);
		MultiValueMap<Integer, Object> bindings = new LinkedMultiValueMap<>();
		PreparedOperation<String> operation = NamedParameterUtils.substituteNamedParameters(
				sql, factory, new MapBindParameterSource(Collections.singletonMap("ids",
						Parameter.from(Arrays.asList("foo", "bar", "baz")))));
		assertThat(operation.toQuery()).isEqualTo(
				"SELECT * FROM person where name IN ($0, $1, $2) or lastname IN ($0, $1, $2)");
		operation.bindTo(new BindTarget() {

			@Override
			public void bind(String identifier, Object value) {
				throw new UnsupportedOperationException();
			}

			@Override
			public void bind(int index, Object value) {
				assertThat(index).isIn(0, 1, 2);
				assertThat(value).isIn("foo", "bar", "baz");
				bindings.add(index, value);
			}

			@Override
			public void bindNull(String identifier, Class<?> type) {
				throw new UnsupportedOperationException();
			}

			@Override
			public void bindNull(int index, Class<?> type) {
				throw new UnsupportedOperationException();
			}
		});
		// Exactly one binding per index proves values were not bound twice.
		assertThat(bindings).containsEntry(0, Collections.singletonList("foo")) //
				.containsEntry(1, Collections.singletonList("bar")) //
				.containsEntry(2, Collections.singletonList("baz"));
	}

	@Test
	public void multipleEqualParameterReferencesForAnonymousMarkersBindsValueMultipleTimes() {
		// Anonymous markers ("?") cannot be reused, so the same value must be
		// bound once per reference.
		String sql = "SELECT * FROM person where name = :id or lastname = :id";
		BindMarkersFactory factory = BindMarkersFactory.anonymous("?");
		PreparedOperation<String> operation = NamedParameterUtils.substituteNamedParameters(
				sql, factory, new MapBindParameterSource(
						Collections.singletonMap("id", Parameter.from("foo"))));
		assertThat(operation.toQuery()).isEqualTo(
				"SELECT * FROM person where name = ? or lastname = ?");
		Map<Integer, Object> bindValues = new LinkedHashMap<>();
		operation.bindTo(new BindTarget() {

			@Override
			public void bind(String identifier, Object value) {
				throw new UnsupportedOperationException();
			}

			@Override
			public void bind(int index, Object value) {
				bindValues.put(index, value);
			}

			@Override
			public void bindNull(String identifier, Class<?> type) {
				throw new UnsupportedOperationException();
			}

			@Override
			public void bindNull(int index, Class<?> type) {
				throw new UnsupportedOperationException();
			}
		});
		assertThat(bindValues).hasSize(2).containsEntry(0, "foo").containsEntry(1, "foo");
	}

	@Test
	public void multipleEqualParameterReferencesBindsNullOnce() {
		// Null bindings follow the same reuse rule as value bindings.
		String sql = "SELECT * FROM person where name = :id or lastname = :id";
		BindMarkersFactory factory = BindMarkersFactory.indexed("$", 0);
		PreparedOperation<String> operation = NamedParameterUtils.substituteNamedParameters(
				sql, factory, new MapBindParameterSource(
						Collections.singletonMap("id", Parameter.empty(String.class))));
		assertThat(operation.toQuery()).isEqualTo(
				"SELECT * FROM person where name = $0 or lastname = $0");
		operation.bindTo(new BindTarget() {

			@Override
			public void bind(String identifier, Object value) {
				throw new UnsupportedOperationException();
			}

			@Override
			public void bind(int index, Object value) {
				throw new UnsupportedOperationException();
			}

			@Override
			public void bindNull(String identifier, Class<?> type) {
				throw new UnsupportedOperationException();
			}

			@Override
			public void bindNull(int index, Class<?> type) {
				assertThat(index).isEqualTo(0);
				assertThat(type).isEqualTo(String.class);
			}
		});
	}

	// Expands a pre-parsed statement using the default markers and no bindings.
	private String expand(ParsedSql sql) {
		return NamedParameterUtils.substituteNamedParameters(sql, BIND_MARKERS,
				new MapBindParameterSource()).toQuery();
	}

	// Parses and expands raw SQL using the default markers and no bindings.
	private String expand(String sql) {
		return NamedParameterUtils.substituteNamedParameters(sql, BIND_MARKERS,
				new MapBindParameterSource()).toQuery();
	}

}

View File

@ -0,0 +1,59 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core.binding;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
/**
 * Unit tests for {@link AnonymousBindMarkers}.
 *
 * @author Mark Paluch
 */
class AnonymousBindMarkersUnitTests {

	@Test
	public void shouldCreateNewBindMarkers() {
		// Each create() call yields an independent marker sequence, both
		// rendering the configured "?" placeholder.
		BindMarkersFactory anonymous = BindMarkersFactory.anonymous("?");
		BindMarkers firstSequence = anonymous.create();
		BindMarkers secondSequence = anonymous.create();
		assertThat(firstSequence.next().getPlaceholder()).isEqualTo("?");
		assertThat(secondSequence.next().getPlaceholder()).isEqualTo("?");
	}

	@Test
	public void shouldBindByIndex() {
		// Bind positions are fixed by marker creation order, not by the order
		// in which bind(...) is invoked.
		BindTarget target = mock(BindTarget.class);
		BindMarkers markers = BindMarkersFactory.anonymous("?").create();
		BindMarker firstMarker = markers.next();
		BindMarker secondMarker = markers.next();
		secondMarker.bind(target, "foo");
		firstMarker.bindNull(target, Object.class);
		verify(target).bindNull(0, Object.class);
		verify(target).bind(1, "foo");
	}

}

View File

@ -0,0 +1,146 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core.binding;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
/**
* Unit tests for {@link Bindings}.
*
* @author Mark Paluch
*/
class BindingsUnitTests {
BindMarkersFactory markersFactory = BindMarkersFactory.indexed("$", 1);
BindTarget bindTarget = mock(BindTarget.class);
@Test
void shouldCreateBindings() {
	// One value binding plus one null binding yields two recorded bindings.
	MutableBindings mutableBindings = new MutableBindings(markersFactory.create());
	mutableBindings.bind(mutableBindings.nextMarker(), "foo");
	mutableBindings.bindNull(mutableBindings.nextMarker(), String.class);
	assertThat(mutableBindings).hasSize(2);
}
@Test
void shouldApplyValueBinding() {
	// A value bound to an explicitly obtained marker is applied at index 0.
	MutableBindings mutableBindings = new MutableBindings(markersFactory.create());
	mutableBindings.bind(mutableBindings.nextMarker(), "foo");
	mutableBindings.apply(bindTarget);
	verify(bindTarget).bind(0, "foo");
}
@Test
void shouldApplySimpleValueBinding() {
	// The shorthand bind(value) allocates the next marker itself ($1) and
	// still applies the value at index 0.
	MutableBindings mutableBindings = new MutableBindings(markersFactory.create());
	BindMarker allocated = mutableBindings.bind("foo");
	mutableBindings.apply(bindTarget);
	assertThat(allocated.getPlaceholder()).isEqualTo("$1");
	verify(bindTarget).bind(0, "foo");
}
@Test
void shouldApplyNullBinding() {
	// A typed null bound to an explicitly obtained marker reaches the target
	// as bindNull at index 0.
	MutableBindings mutableBindings = new MutableBindings(markersFactory.create());
	mutableBindings.bindNull(mutableBindings.nextMarker(), String.class);
	mutableBindings.apply(bindTarget);
	verify(bindTarget).bindNull(0, String.class);
}
@Test
void shouldApplySimpleNullBinding() {
MutableBindings bindings = new MutableBindings(markersFactory.create());
BindMarker marker = bindings.bindNull(String.class);
bindings.apply(bindTarget);
assertThat(marker.getPlaceholder()).isEqualTo("$1");
verify(bindTarget).bindNull(0, String.class);
}
@Test
void shouldConsumeBindings() {
MutableBindings bindings = new MutableBindings(markersFactory.create());
bindings.bind(bindings.nextMarker(), "foo");
bindings.bindNull(bindings.nextMarker(), String.class);
AtomicInteger counter = new AtomicInteger();
bindings.forEach(binding -> {
if (binding.hasValue()) {
counter.incrementAndGet();
assertThat(binding.getValue()).isEqualTo("foo");
assertThat(binding.getBindMarker().getPlaceholder()).isEqualTo("$1");
}
if (binding.isNull()) {
counter.incrementAndGet();
assertThat(((Bindings.NullBinding) binding).getValueType()).isEqualTo(String.class);
assertThat(binding.getBindMarker().getPlaceholder()).isEqualTo("$2");
}
});
assertThat(counter).hasValue(2);
}
@Test
void shouldMergeBindings() {
BindMarkers markers = markersFactory.create();
BindMarker shared = markers.next();
BindMarker leftMarker = markers.next();
List<Bindings.Binding> left = new ArrayList<>();
left.add(new Bindings.NullBinding(shared, String.class));
left.add(new Bindings.ValueBinding(leftMarker, "left"));
BindMarker rightMarker = markers.next();
List<Bindings.Binding> right = new ArrayList<>();
left.add(new Bindings.ValueBinding(shared, "override"));
left.add(new Bindings.ValueBinding(rightMarker, "right"));
Bindings merged = Bindings.merge(new Bindings(left), new Bindings(right));
assertThat(merged).hasSize(3);
merged.apply(bindTarget);
verify(bindTarget).bind(0, "override");
verify(bindTarget).bind(1, "left");
verify(bindTarget).bind(2, "right");
}
}

View File

@ -0,0 +1,103 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core.binding;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
/**
* Unit tests for {@link IndexedBindMarkers}.
*
* @author Mark Paluch
*/
class IndexedBindMarkersUnitTests {

	@Test
	void shouldCreateNewBindMarkers() {
		BindMarkersFactory factory = BindMarkersFactory.indexed("$", 0);

		// Every create() call must yield an independent counter starting at the offset.
		BindMarkers markersOne = factory.create();
		BindMarkers markersTwo = factory.create();

		assertThat(markersOne.next().getPlaceholder()).isEqualTo("$0");
		assertThat(markersTwo.next().getPlaceholder()).isEqualTo("$0");
	}

	@Test
	void shouldCreateNewBindMarkersWithOffset() {
		BindTarget target = mock(BindTarget.class);

		// Placeholders start at the offset (1), yet binding to the target is zero-based.
		BindMarkers markers = BindMarkersFactory.indexed("$", 1).create();

		BindMarker marker1 = markers.next();
		marker1.bind(target, "foo");

		BindMarker marker2 = markers.next();
		marker2.bind(target, "bar");

		assertThat(marker1.getPlaceholder()).isEqualTo("$1");
		assertThat(marker2.getPlaceholder()).isEqualTo("$2");
		verify(target).bind(0, "foo");
		verify(target).bind(1, "bar");
	}

	@Test
	void nextShouldIncrementBindMarker() {
		// Incrementing behavior must hold regardless of the placeholder prefix.
		for (String prefix : new String[] { "$", "?" }) {
			BindMarkers markers = BindMarkersFactory.indexed(prefix, 0).create();

			BindMarker first = markers.next();
			BindMarker second = markers.next();

			assertThat(first.getPlaceholder()).isEqualTo(prefix + "0");
			assertThat(second.getPlaceholder()).isEqualTo(prefix + "1");
		}
	}

	@Test
	void bindValueShouldBindByIndex() {
		BindTarget target = mock(BindTarget.class);
		BindMarkers markers = BindMarkersFactory.indexed("$", 0).create();

		markers.next().bind(target, "foo");
		markers.next().bind(target, "bar");

		verify(target).bind(0, "foo");
		verify(target).bind(1, "bar");
	}

	@Test
	void bindNullShouldBindByIndex() {
		BindTarget target = mock(BindTarget.class);
		BindMarkers markers = BindMarkersFactory.indexed("$", 0).create();

		markers.next(); // skip the first marker on purpose
		markers.next().bindNull(target, Integer.class);

		verify(target).bindNull(1, Integer.class);
	}

}

View File

@ -0,0 +1,115 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core.binding;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
/**
* Unit tests for {@link NamedBindMarkers}.
*
* @author Mark Paluch
*/
class NamedBindMarkersUnitTests {

	@Test
	void shouldCreateNewBindMarkers() {
		BindMarkersFactory factory = BindMarkersFactory.named("@", "p", 32);

		// Each create() call must produce an independent, zero-based name sequence.
		BindMarkers markersOne = factory.create();
		BindMarkers markersTwo = factory.create();

		assertThat(markersOne.next().getPlaceholder()).isEqualTo("@p0");
		assertThat(markersTwo.next().getPlaceholder()).isEqualTo("@p0");
	}

	@ParameterizedTest
	@ValueSource(strings = { "$", "?" })
	void nextShouldIncrementBindMarker(String prefix) {
		BindMarkers markers = BindMarkersFactory.named(prefix, "p", 32).create();

		BindMarker first = markers.next();
		BindMarker second = markers.next();

		assertThat(first.getPlaceholder()).isEqualTo(prefix + "p0");
		assertThat(second.getPlaceholder()).isEqualTo(prefix + "p1");
	}

	@Test
	void nextShouldConsiderNameHint() {
		BindMarkers markers = BindMarkersFactory.named("@", "x", 32).create();

		// The hint is appended after the generated counter-based name.
		BindMarker hinted = markers.next("foo1bar");
		BindMarker plain = markers.next();

		assertThat(hinted.getPlaceholder()).isEqualTo("@x0foo1bar");
		assertThat(plain.getPlaceholder()).isEqualTo("@x1");
	}

	@Test
	void nextShouldConsiderFilteredNameHint() {
		// Hint filter keeps alphabetic characters only, dropping digits and symbols.
		BindMarkers markers = BindMarkersFactory.named("@", "p", 32,
				s -> s.chars().filter(Character::isAlphabetic).collect(StringBuilder::new,
						StringBuilder::appendCodePoint, StringBuilder::append).toString()).create();

		BindMarker hinted = markers.next("foo1.bar?");
		BindMarker plain = markers.next();

		assertThat(hinted.getPlaceholder()).isEqualTo("@p0foobar");
		assertThat(plain.getPlaceholder()).isEqualTo("@p1");
	}

	@Test
	void nextShouldConsiderNameLimit() {
		BindMarkers markers = BindMarkersFactory.named("@", "p", 10).create();

		// "p0" + hint is truncated to the configured 10-character name limit.
		BindMarker marker = markers.next("123456789");

		assertThat(marker.getPlaceholder()).isEqualTo("@p012345678");
	}

	@Test
	void bindValueShouldBindByName() {
		BindTarget target = mock(BindTarget.class);
		BindMarkers markers = BindMarkersFactory.named("@", "p", 32).create();

		markers.next().bind(target, "foo");
		markers.next().bind(target, "bar");

		verify(target).bind("p0", "foo");
		verify(target).bind("p1", "bar");
	}

	@Test
	void bindNullShouldBindByName() {
		BindTarget target = mock(BindTarget.class);
		BindMarkers markers = BindMarkersFactory.named("@", "p", 32).create();

		markers.next(); // skip the first marker on purpose
		markers.next().bindNull(target, Integer.class);

		verify(target).bindNull("p1", Integer.class);
	}

}

View File

@ -0,0 +1,105 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core
import io.mockk.every
import io.mockk.mockk
import io.mockk.verify
import kotlinx.coroutines.runBlocking
import org.junit.jupiter.api.Test
import reactor.core.publisher.Mono
/**
* Unit tests for [DatabaseClient] extensions.
*
* @author Sebastien Deleuze
* @author Jonas Bark
* @author Mark Paluch
*/
class DatabaseClientExtensionsTests {

	@Test
	fun bindByIndexShouldBindValue() {
		val executeSpec = mockk<DatabaseClient.GenericExecuteSpec>()
		every { executeSpec.bind(eq(0), any()) } returns executeSpec

		runBlocking {
			executeSpec.bind<String>(0, "foo")
		}

		// Non-null values are wrapped in a Parameter carrying value and type.
		verify {
			executeSpec.bind(0, Parameter.fromOrEmpty("foo", String::class.java))
		}
	}

	@Test
	fun bindByIndexShouldBindNull() {
		val executeSpec = mockk<DatabaseClient.GenericExecuteSpec>()
		every { executeSpec.bind(eq(0), any()) } returns executeSpec

		runBlocking {
			executeSpec.bind<String>(0, null)
		}

		// Null values become an empty Parameter that still conveys the type.
		verify {
			executeSpec.bind(0, Parameter.empty(String::class.java))
		}
	}

	@Test
	fun bindByNameShouldBindValue() {
		val executeSpec = mockk<DatabaseClient.GenericExecuteSpec>()
		every { executeSpec.bind(eq("field"), any()) } returns executeSpec

		runBlocking {
			executeSpec.bind<String>("field", "foo")
		}

		verify {
			executeSpec.bind("field", Parameter.fromOrEmpty("foo", String::class.java))
		}
	}

	@Test
	fun bindByNameShouldBindNull() {
		val executeSpec = mockk<DatabaseClient.GenericExecuteSpec>()
		every { executeSpec.bind(eq("field"), any()) } returns executeSpec

		runBlocking {
			executeSpec.bind<String>("field", null)
		}

		verify {
			executeSpec.bind("field", Parameter.empty(String::class.java))
		}
	}

	@Test
	fun genericExecuteSpecAwait() {
		val executeSpec = mockk<DatabaseClient.GenericExecuteSpec>()
		every { executeSpec.then() } returns Mono.empty()

		// await() must delegate to the reactive then() completion signal.
		runBlocking {
			executeSpec.await()
		}

		verify {
			executeSpec.then()
		}
	}

}

View File

@ -0,0 +1,166 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core
import io.mockk.every
import io.mockk.mockk
import io.mockk.verify
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.runBlocking
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.junit.jupiter.api.Test
import org.springframework.dao.EmptyResultDataAccessException
import reactor.core.publisher.Flux
import reactor.core.publisher.Mono
/**
* Unit tests for [RowsFetchSpec] extensions.
*
* @author Sebastien Deleuze
* @author Mark Paluch
*/
class RowsFetchSpecExtensionsTests {

	@Test
	fun awaitOneWithValue() {
		val fetchSpec = mockk<RowsFetchSpec<String>>()
		every { fetchSpec.one() } returns Mono.just("foo")

		runBlocking {
			assertThat(fetchSpec.awaitOne()).isEqualTo("foo")
		}

		verify {
			fetchSpec.one()
		}
	}

	@Test
	fun awaitOneWithNull() {
		val fetchSpec = mockk<RowsFetchSpec<String>>()
		every { fetchSpec.one() } returns Mono.empty()

		// awaitOne() treats an empty Mono as a missing row and fails.
		assertThatExceptionOfType(EmptyResultDataAccessException::class.java).isThrownBy {
			runBlocking { fetchSpec.awaitOne() }
		}

		verify {
			fetchSpec.one()
		}
	}

	@Test
	fun awaitOneOrNullWithValue() {
		val fetchSpec = mockk<RowsFetchSpec<String>>()
		every { fetchSpec.one() } returns Mono.just("foo")

		runBlocking {
			assertThat(fetchSpec.awaitOneOrNull()).isEqualTo("foo")
		}

		verify {
			fetchSpec.one()
		}
	}

	@Test
	fun awaitOneOrNullWithNull() {
		val fetchSpec = mockk<RowsFetchSpec<String>>()
		every { fetchSpec.one() } returns Mono.empty()

		// The OrNull variant maps an empty Mono to null instead of throwing.
		runBlocking {
			assertThat(fetchSpec.awaitOneOrNull()).isNull()
		}

		verify {
			fetchSpec.one()
		}
	}

	@Test
	fun awaitFirstWithValue() {
		val fetchSpec = mockk<RowsFetchSpec<String>>()
		every { fetchSpec.first() } returns Mono.just("foo")

		runBlocking {
			assertThat(fetchSpec.awaitFirst()).isEqualTo("foo")
		}

		verify {
			fetchSpec.first()
		}
	}

	@Test
	fun awaitFirstWithNull() {
		val fetchSpec = mockk<RowsFetchSpec<String>>()
		every { fetchSpec.first() } returns Mono.empty()

		// awaitFirst() treats an empty Mono as a missing row and fails.
		assertThatExceptionOfType(EmptyResultDataAccessException::class.java).isThrownBy {
			runBlocking { fetchSpec.awaitFirst() }
		}

		verify {
			fetchSpec.first()
		}
	}

	@Test
	fun awaitFirstOrNullWithValue() {
		val fetchSpec = mockk<RowsFetchSpec<String>>()
		every { fetchSpec.first() } returns Mono.just("foo")

		runBlocking {
			assertThat(fetchSpec.awaitFirstOrNull()).isEqualTo("foo")
		}

		verify {
			fetchSpec.first()
		}
	}

	@Test
	fun awaitFirstOrNullWithNull() {
		val fetchSpec = mockk<RowsFetchSpec<String>>()
		every { fetchSpec.first() } returns Mono.empty()

		runBlocking {
			assertThat(fetchSpec.awaitFirstOrNull()).isNull()
		}

		verify {
			fetchSpec.first()
		}
	}

	@Test
	@ExperimentalCoroutinesApi
	fun allAsFlow() {
		val fetchSpec = mockk<RowsFetchSpec<String>>()
		every { fetchSpec.all() } returns Flux.just("foo", "bar", "baz")

		// flow() must expose every element emitted by all().
		runBlocking {
			assertThat(fetchSpec.flow().toList()).contains("foo", "bar", "baz")
		}

		verify {
			fetchSpec.all()
		}
	}

}

View File

@ -0,0 +1,48 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.r2dbc.core
import io.mockk.every
import io.mockk.mockk
import io.mockk.verify
import kotlinx.coroutines.runBlocking
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test
import reactor.core.publisher.Mono
/**
* Unit tests for [UpdatedRowsFetchSpec] extensions.
*
* @author Fred Montariol
*/
class UpdatedRowsFetchSpecExtensionsTests {

	@Test
	fun awaitRowsUpdatedWithValue() {
		val fetchSpec = mockk<UpdatedRowsFetchSpec>()
		every { fetchSpec.rowsUpdated() } returns Mono.just(42)

		// The suspending extension must unwrap the Mono's update count.
		runBlocking {
			assertThat(fetchSpec.awaitRowsUpdated()).isEqualTo(42)
		}

		verify {
			fetchSpec.rowsUpdated()
		}
	}

}

View File

@ -0,0 +1,5 @@
-- Test fixture: drops and recreates the "users" table. The unconditional DROP
-- fails when the table does not yet exist, which callers may tolerate.
-- Failed DROP can be ignored if necessary
drop table users;
-- Create the test table
create table users (last_name varchar(50) not null);

View File

@ -0,0 +1,3 @@
-- Idempotent recreate of the "users" test table; "drop table ... if exists"
-- uses H2's postfix IF EXISTS syntax, so the script succeeds on a fresh database.
drop table users if exists;
create table users (last_name varchar(50) not null);

View File

@ -0,0 +1,2 @@
-- Test data for the "users" table.
-- NOTE(review): statements end with "@@" instead of ";" — assumes the executing
-- test configures "@@" as the statement separator; confirm against the caller.
insert into users (last_name) values ('Heisenberg')@@
insert into users (last_name) values ('Jesse')@@

Some files were not shown because too many files have changed in this diff Show More