diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java index 86d7a7bfa58..013aa730ebb 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java @@ -20,9 +20,11 @@ import java.io.ByteArrayOutputStream; import java.io.DataOutputStream; import java.io.IOException; import java.util.Collections; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.concurrent.ConcurrentHashMap; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -49,6 +51,27 @@ public class StompEncoder { private static final Log logger = LogFactory.getLog(StompEncoder.class); + private static final int HEADER_KEY_CACHE_LIMIT = 32; + + + private final Map headerKeyAccessCache = + new ConcurrentHashMap(HEADER_KEY_CACHE_LIMIT); + + @SuppressWarnings("serial") + private final Map headerKeyUpdateCache = + new LinkedHashMap(HEADER_KEY_CACHE_LIMIT, 0.75f, true) { + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + if (size() > HEADER_KEY_CACHE_LIMIT) { + headerKeyAccessCache.remove(eldest.getKey()); + return true; + } + else { + return false; + } + } + }; + /** * Encodes the given STOMP {@code message} into a {@code byte[]} @@ -129,11 +152,11 @@ public class StompEncoder { values = Collections.singletonList(StompHeaderAccessor.getPasscode(headers)); } - byte[] encodedKey = encodeHeaderString(entry.getKey(), shouldEscape); + byte[] encodedKey = encodeHeaderKey(entry.getKey(), shouldEscape); for (String value : values) { output.write(encodedKey); output.write(COLON); - output.write(encodeHeaderString(value, shouldEscape)); + output.write(encodeHeaderValue(value, shouldEscape)); output.write(LF); } } @@ -146,7 +169,23 @@ public class StompEncoder { } } - private byte[] encodeHeaderString(String input, boolean escape) { + private byte[] encodeHeaderKey(String input, boolean escape) { + String inputToUse = (escape ? escape(input) : input); + if (this.headerKeyAccessCache.containsKey(inputToUse)) { + return this.headerKeyAccessCache.get(inputToUse); + } + synchronized (this.headerKeyUpdateCache) { + byte[] bytes = this.headerKeyUpdateCache.get(inputToUse); + if (bytes == null) { + bytes = inputToUse.getBytes(StompDecoder.UTF8_CHARSET); + this.headerKeyAccessCache.put(inputToUse, bytes); + this.headerKeyUpdateCache.put(inputToUse, bytes); + } + return bytes; + } + } + + private byte[] encodeHeaderValue(String input, boolean escape) { String inputToUse = (escape ? escape(input) : input); return inputToUse.getBytes(StompDecoder.UTF8_CHARSET); } @@ -156,26 +195,38 @@ public class StompEncoder { * "Value Encoding". */ private String escape(String inString) { - StringBuilder sb = new StringBuilder(inString.length()); + StringBuilder sb = null; for (int i = 0; i < inString.length(); i++) { char c = inString.charAt(i); if (c == '\\') { + sb = getStringBuilder(sb, inString, i); sb.append("\\\\"); } else if (c == ':') { + sb = getStringBuilder(sb, inString, i); sb.append("\\c"); } else if (c == '\n') { - sb.append("\\n"); + sb = getStringBuilder(sb, inString, i); + sb.append("\\n"); } else if (c == '\r') { + sb = getStringBuilder(sb, inString, i); sb.append("\\r"); } - else { + else if (sb != null){ sb.append(c); } } - return sb.toString(); + return (sb != null ? sb.toString() : inString); + } + + private StringBuilder getStringBuilder(StringBuilder sb, String inString, int i) { + if (sb == null) { + sb = new StringBuilder(inString.length()); + sb.append(inString.substring(0, i)); + } + return sb; } private void writeBody(byte[] payload, DataOutputStream output) throws IOException { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java index b83a1dc5e44..c19ff778078 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java @@ -97,9 +97,9 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE @SuppressWarnings("deprecation") private org.springframework.messaging.simp.user.UserSessionRegistry userSessionRegistry; - private final StompEncoder stompEncoder = new StompEncoder(); + private StompEncoder stompEncoder = new StompEncoder(); - private final StompDecoder stompDecoder = new StompDecoder(); + private StompDecoder stompDecoder = new StompDecoder(); private final Map decoders = new ConcurrentHashMap(); @@ -171,6 +171,24 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE return this.userSessionRegistry; } + /** + * Configure a {@link StompEncoder} for encoding STOMP frames + * @param encoder the encoder + * @since 4.3.5 + */ + public void setEncoder(StompEncoder encoder) { + this.stompEncoder = encoder; + } + + /** + * Configure a {@link StompDecoder} for decoding STOMP frames + * @param decoder the decoder + * @since 4.3.5 + */ + public void setDecoder(StompDecoder decoder) { + this.stompDecoder = decoder; + } + /** * Configure a {@link MessageHeaderInitializer} to apply to the headers of all * messages created from decoded STOMP frames and other messages sent to the