diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java index ba3d42937c0..31b9d15d923 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -99,7 +99,7 @@ public class SubProtocolWebSocketHandler private final ReentrantLock sessionCheckLock = new ReentrantLock(); - private final Stats stats = new Stats(); + private final DefaultStats stats = new DefaultStats(); private volatile boolean running = false; @@ -253,6 +253,15 @@ public class SubProtocolWebSocketHandler return this.stats.toString(); } + /** + * Return a {@link Stats} object that containers various session counters. + * @since 5.2 + */ + public Stats getStats() { + return this.stats; + } + + @Override public final void start() { @@ -560,7 +569,28 @@ public class SubProtocolWebSocketHandler } - private class Stats { + /** + * Contract for access to session counters. + * @since 5.2 + */ + public interface Stats { + + int getTotalSessions(); + + int getWebSocketSessions(); + + int getHttpStreamingSessions(); + + int getHttpPollingSessions(); + + int getLimitExceededSessions(); + + int getNoMessagesReceivedSessions(); + + int getTransportErrorSessions(); + } + + private class DefaultStats implements Stats { private final AtomicInteger total = new AtomicInteger(); @@ -576,28 +606,64 @@ public class SubProtocolWebSocketHandler private final AtomicInteger transportError = new AtomicInteger(); - public void incrementSessionCount(WebSocketSession session) { + + @Override + public int getTotalSessions() { + return this.total.get(); + } + + @Override + public int getWebSocketSessions() { + return this.webSocket.get(); + } + + @Override + public int getHttpStreamingSessions() { + return this.httpStreaming.get(); + } + + @Override + public int getHttpPollingSessions() { + return this.httpPolling.get(); + } + + @Override + public int getLimitExceededSessions() { + return this.limitExceeded.get(); + } + + @Override + public int getNoMessagesReceivedSessions() { + return this.noMessagesReceived.get(); + } + + @Override + public int getTransportErrorSessions() { + return this.transportError.get(); + } + + void incrementSessionCount(WebSocketSession session) { getCountFor(session).incrementAndGet(); this.total.incrementAndGet(); } - public void decrementSessionCount(WebSocketSession session) { + void decrementSessionCount(WebSocketSession session) { getCountFor(session).decrementAndGet(); } - public void incrementLimitExceededCount() { + void incrementLimitExceededCount() { this.limitExceeded.incrementAndGet(); } - public void incrementNoMessagesReceivedCount() { + void incrementNoMessagesReceivedCount() { this.noMessagesReceived.incrementAndGet(); } - public void incrementTransportError() { + void incrementTransportError() { this.transportError.incrementAndGet(); } - private AtomicInteger getCountFor(WebSocketSession session) { + AtomicInteger getCountFor(WebSocketSession session) { if (session instanceof PollingSockJsSession) { return this.httpPolling; }