mirror of https://github.com/apache/kafka.git
				
				
				
			KAFKA-9652: Fix throttle metric in RequestChannel and request log due to KIP-219 (#8567)
After KIP-219, responses are sent immediately and we rely on a combination of clients and muting of the channel to throttle. The result of this is that we need to track `apiThrottleTimeMs` as an explicit value instead of inferring it. On the other hand, we no longer need `apiRemoteCompleteTimeNanos`. Extend `BaseQuotaTest` to verify that throttle time in the request channel metrics are being set. Given the nature of the throttling numbers, the test is not particularly precise. I included a few clean-ups: * Pass KafkaMetric to QuotaViolationException so that the caller doesn't have to retrieve it from the metrics registry. * Inline Supplier in SocketServer (use SAM). * Reduce redundant `time.milliseconds` and `time.nanoseconds`calls. * Use monotonic clock in ThrottledChannel and simplify `compareTo` method. * Simplify `TimerTaskList.compareTo`. * Consolidate the number of places where we update `apiLocalCompleteTimeNanos` and `responseCompleteTimeNanos`. * Added `toString` to ByteBufferSend` and `MultiRecordsSend`. * Restrict access to methods in `QuotaTestClients` to expose only what we need to. Reviewers: Jun Rao <junrao@gmail.com>
This commit is contained in:
		
							parent
							
								
									8a83025109
								
							
						
					
					
						commit
						322b10964c
					
				|  | @ -17,7 +17,6 @@ | ||||||
| package org.apache.kafka.common.metrics; | package org.apache.kafka.common.metrics; | ||||||
| 
 | 
 | ||||||
| import org.apache.kafka.common.KafkaException; | import org.apache.kafka.common.KafkaException; | ||||||
| import org.apache.kafka.common.MetricName; |  | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
|  * Thrown when a sensor records a value that causes a metric to go outside the bounds configured as its quota |  * Thrown when a sensor records a value that causes a metric to go outside the bounds configured as its quota | ||||||
|  | @ -25,18 +24,18 @@ import org.apache.kafka.common.MetricName; | ||||||
| public class QuotaViolationException extends KafkaException { | public class QuotaViolationException extends KafkaException { | ||||||
| 
 | 
 | ||||||
|     private static final long serialVersionUID = 1L; |     private static final long serialVersionUID = 1L; | ||||||
|     private final MetricName metricName; |     private final KafkaMetric metric; | ||||||
|     private final double value; |     private final double value; | ||||||
|     private final double bound; |     private final double bound; | ||||||
| 
 | 
 | ||||||
|     public QuotaViolationException(MetricName metricName, double value, double bound) { |     public QuotaViolationException(KafkaMetric metric, double value, double bound) { | ||||||
|         this.metricName = metricName; |         this.metric = metric; | ||||||
|         this.value = value; |         this.value = value; | ||||||
|         this.bound = bound; |         this.bound = bound; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public MetricName metricName() { |     public KafkaMetric metric() { | ||||||
|         return metricName; |         return metric; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public double value() { |     public double value() { | ||||||
|  | @ -51,7 +50,7 @@ public class QuotaViolationException extends KafkaException { | ||||||
|     public String toString() { |     public String toString() { | ||||||
|         return getClass().getName() |         return getClass().getName() | ||||||
|                 + ": '" |                 + ": '" | ||||||
|                 + metricName |                 + metric.metricName() | ||||||
|                 + "' violated quota. Actual: " |                 + "' violated quota. Actual: " | ||||||
|                 + value |                 + value | ||||||
|                 + ", Threshold: " |                 + ", Threshold: " | ||||||
|  |  | ||||||
|  | @ -209,8 +209,7 @@ public final class Sensor { | ||||||
|                 if (quota != null) { |                 if (quota != null) { | ||||||
|                     double value = metric.measurableValue(timeMs); |                     double value = metric.measurableValue(timeMs); | ||||||
|                     if (!quota.acceptable(value)) { |                     if (!quota.acceptable(value)) { | ||||||
|                         throw new QuotaViolationException(metric.metricName(), value, |                         throw new QuotaViolationException(metric, value, quota.bound()); | ||||||
|                             quota.bound()); |  | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|  |  | ||||||
|  | @ -68,4 +68,14 @@ public class ByteBufferSend implements Send { | ||||||
|     public long remaining() { |     public long remaining() { | ||||||
|         return remaining; |         return remaining; | ||||||
|     } |     } | ||||||
|  | 
 | ||||||
|  |     @Override | ||||||
|  |     public String toString() { | ||||||
|  |         return "ByteBufferSend(" + | ||||||
|  |             "destination='" + destination + "'" + | ||||||
|  |             ", size=" + size + | ||||||
|  |             ", remaining=" + remaining + | ||||||
|  |             ", pending=" + pending + | ||||||
|  |             ')'; | ||||||
|  |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -118,6 +118,15 @@ public class MultiRecordsSend implements Send { | ||||||
|         return recordConversionStats; |         return recordConversionStats; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     @Override | ||||||
|  |     public String toString() { | ||||||
|  |         return "MultiRecordsSend(" + | ||||||
|  |             "dest='" + dest + "'" + | ||||||
|  |             ", size=" + size + | ||||||
|  |             ", totalWritten=" + totalWritten + | ||||||
|  |             ')'; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     private void updateRecordConversionStats(Send completedSend) { |     private void updateRecordConversionStats(Send completedSend) { | ||||||
|         // The underlying send might have accumulated statistics that need to be recorded. For example, |         // The underlying send might have accumulated statistics that need to be recorded. For example, | ||||||
|         // LazyDownConversionRecordsSend accumulates statistics related to the number of bytes down-converted, the amount |         // LazyDownConversionRecordsSend accumulates statistics related to the number of bytes down-converted, the amount | ||||||
|  |  | ||||||
|  | @ -87,8 +87,8 @@ object RequestChannel extends Logging { | ||||||
|     @volatile var apiLocalCompleteTimeNanos = -1L |     @volatile var apiLocalCompleteTimeNanos = -1L | ||||||
|     @volatile var responseCompleteTimeNanos = -1L |     @volatile var responseCompleteTimeNanos = -1L | ||||||
|     @volatile var responseDequeueTimeNanos = -1L |     @volatile var responseDequeueTimeNanos = -1L | ||||||
|     @volatile var apiRemoteCompleteTimeNanos = -1L |  | ||||||
|     @volatile var messageConversionsTimeNanos = 0L |     @volatile var messageConversionsTimeNanos = 0L | ||||||
|  |     @volatile var apiThrottleTimeMs = 0L | ||||||
|     @volatile var temporaryMemoryBytes = 0L |     @volatile var temporaryMemoryBytes = 0L | ||||||
|     @volatile var recordNetworkThreadTimeCallback: Option[Long => Unit] = None |     @volatile var recordNetworkThreadTimeCallback: Option[Long => Unit] = None | ||||||
| 
 | 
 | ||||||
|  | @ -170,16 +170,6 @@ object RequestChannel extends Logging { | ||||||
| 
 | 
 | ||||||
|     def updateRequestMetrics(networkThreadTimeNanos: Long, response: Response): Unit = { |     def updateRequestMetrics(networkThreadTimeNanos: Long, response: Response): Unit = { | ||||||
|       val endTimeNanos = Time.SYSTEM.nanoseconds |       val endTimeNanos = Time.SYSTEM.nanoseconds | ||||||
|       // In some corner cases, apiLocalCompleteTimeNanos may not be set when the request completes if the remote |  | ||||||
|       // processing time is really small. This value is set in KafkaApis from a request handling thread. |  | ||||||
|       // This may be read in a network thread before the actual update happens in KafkaApis which will cause us to |  | ||||||
|       // see a negative value here. In that case, use responseCompleteTimeNanos as apiLocalCompleteTimeNanos. |  | ||||||
|       if (apiLocalCompleteTimeNanos < 0) |  | ||||||
|         apiLocalCompleteTimeNanos = responseCompleteTimeNanos |  | ||||||
|       // If the apiRemoteCompleteTimeNanos is not set (i.e., for requests that do not go through a purgatory), then it is |  | ||||||
|       // the same as responseCompleteTimeNanos. |  | ||||||
|       if (apiRemoteCompleteTimeNanos < 0) |  | ||||||
|         apiRemoteCompleteTimeNanos = responseCompleteTimeNanos |  | ||||||
| 
 | 
 | ||||||
|       /** |       /** | ||||||
|        * Converts nanos to millis with micros precision as additional decimal places in the request log have low |        * Converts nanos to millis with micros precision as additional decimal places in the request log have low | ||||||
|  | @ -193,8 +183,7 @@ object RequestChannel extends Logging { | ||||||
| 
 | 
 | ||||||
|       val requestQueueTimeMs = nanosToMs(requestDequeueTimeNanos - startTimeNanos) |       val requestQueueTimeMs = nanosToMs(requestDequeueTimeNanos - startTimeNanos) | ||||||
|       val apiLocalTimeMs = nanosToMs(apiLocalCompleteTimeNanos - requestDequeueTimeNanos) |       val apiLocalTimeMs = nanosToMs(apiLocalCompleteTimeNanos - requestDequeueTimeNanos) | ||||||
|       val apiRemoteTimeMs = nanosToMs(apiRemoteCompleteTimeNanos - apiLocalCompleteTimeNanos) |       val apiRemoteTimeMs = nanosToMs(responseCompleteTimeNanos - apiLocalCompleteTimeNanos) | ||||||
|       val apiThrottleTimeMs = nanosToMs(responseCompleteTimeNanos - apiRemoteCompleteTimeNanos) |  | ||||||
|       val responseQueueTimeMs = nanosToMs(responseDequeueTimeNanos - responseCompleteTimeNanos) |       val responseQueueTimeMs = nanosToMs(responseDequeueTimeNanos - responseCompleteTimeNanos) | ||||||
|       val responseSendTimeMs = nanosToMs(endTimeNanos - responseDequeueTimeNanos) |       val responseSendTimeMs = nanosToMs(endTimeNanos - responseDequeueTimeNanos) | ||||||
|       val messageConversionsTimeMs = nanosToMs(messageConversionsTimeNanos) |       val messageConversionsTimeMs = nanosToMs(messageConversionsTimeNanos) | ||||||
|  | @ -215,7 +204,7 @@ object RequestChannel extends Logging { | ||||||
|         m.requestQueueTimeHist.update(Math.round(requestQueueTimeMs)) |         m.requestQueueTimeHist.update(Math.round(requestQueueTimeMs)) | ||||||
|         m.localTimeHist.update(Math.round(apiLocalTimeMs)) |         m.localTimeHist.update(Math.round(apiLocalTimeMs)) | ||||||
|         m.remoteTimeHist.update(Math.round(apiRemoteTimeMs)) |         m.remoteTimeHist.update(Math.round(apiRemoteTimeMs)) | ||||||
|         m.throttleTimeHist.update(Math.round(apiThrottleTimeMs)) |         m.throttleTimeHist.update(apiThrottleTimeMs) | ||||||
|         m.responseQueueTimeHist.update(Math.round(responseQueueTimeMs)) |         m.responseQueueTimeHist.update(Math.round(responseQueueTimeMs)) | ||||||
|         m.responseSendTimeHist.update(Math.round(responseSendTimeMs)) |         m.responseSendTimeHist.update(Math.round(responseSendTimeMs)) | ||||||
|         m.totalTimeHist.update(Math.round(totalTimeMs)) |         m.totalTimeHist.update(Math.round(totalTimeMs)) | ||||||
|  | @ -276,12 +265,6 @@ object RequestChannel extends Logging { | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   abstract class Response(val request: Request) { |   abstract class Response(val request: Request) { | ||||||
|     locally { |  | ||||||
|       val nowNs = Time.SYSTEM.nanoseconds |  | ||||||
|       request.responseCompleteTimeNanos = nowNs |  | ||||||
|       if (request.apiLocalCompleteTimeNanos == -1L) |  | ||||||
|         request.apiLocalCompleteTimeNanos = nowNs |  | ||||||
|     } |  | ||||||
| 
 | 
 | ||||||
|     def processor: Int = request.processor |     def processor: Int = request.processor | ||||||
| 
 | 
 | ||||||
|  | @ -326,7 +309,7 @@ object RequestChannel extends Logging { | ||||||
|   } |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| class RequestChannel(val queueSize: Int, val metricNamePrefix : String) extends KafkaMetricsGroup { | class RequestChannel(val queueSize: Int, val metricNamePrefix : String, time: Time) extends KafkaMetricsGroup { | ||||||
|   import RequestChannel._ |   import RequestChannel._ | ||||||
|   val metrics = new RequestChannel.Metrics |   val metrics = new RequestChannel.Metrics | ||||||
|   private val requestQueue = new ArrayBlockingQueue[BaseRequest](queueSize) |   private val requestQueue = new ArrayBlockingQueue[BaseRequest](queueSize) | ||||||
|  | @ -362,6 +345,7 @@ class RequestChannel(val queueSize: Int, val metricNamePrefix : String) extends | ||||||
| 
 | 
 | ||||||
|   /** Send a response back to the socket server to be sent over the network */ |   /** Send a response back to the socket server to be sent over the network */ | ||||||
|   def sendResponse(response: RequestChannel.Response): Unit = { |   def sendResponse(response: RequestChannel.Response): Unit = { | ||||||
|  | 
 | ||||||
|     if (isTraceEnabled) { |     if (isTraceEnabled) { | ||||||
|       val requestHeader = response.request.header |       val requestHeader = response.request.header | ||||||
|       val message = response match { |       val message = response match { | ||||||
|  | @ -379,6 +363,18 @@ class RequestChannel(val queueSize: Int, val metricNamePrefix : String) extends | ||||||
|       trace(message) |       trace(message) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     response match { | ||||||
|  |       // We should only send one of the following per request | ||||||
|  |       case _: SendResponse | _: NoOpResponse | _: CloseConnectionResponse => | ||||||
|  |         val request = response.request | ||||||
|  |         val timeNanos = time.nanoseconds() | ||||||
|  |         request.responseCompleteTimeNanos = timeNanos | ||||||
|  |         if (request.apiLocalCompleteTimeNanos == -1L) | ||||||
|  |           request.apiLocalCompleteTimeNanos = timeNanos | ||||||
|  |       // For a given request, these may happen in addition to one in the previous section, skip updating the metrics | ||||||
|  |       case _: StartThrottlingResponse | _: EndThrottlingResponse => () | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     val processor = processors.get(response.processor) |     val processor = processors.get(response.processor) | ||||||
|     // The processor may be null if it was shutdown. In this case, the connections |     // The processor may be null if it was shutdown. In this case, the connections | ||||||
|     // are closed, so the response is dropped. |     // are closed, so the response is dropped. | ||||||
|  | @ -444,7 +440,8 @@ class RequestMetrics(name: String) extends KafkaMetricsGroup { | ||||||
|   val localTimeHist = newHistogram(LocalTimeMs, biased = true, tags) |   val localTimeHist = newHistogram(LocalTimeMs, biased = true, tags) | ||||||
|   // time a request takes to wait on remote brokers (currently only relevant to fetch and produce requests) |   // time a request takes to wait on remote brokers (currently only relevant to fetch and produce requests) | ||||||
|   val remoteTimeHist = newHistogram(RemoteTimeMs, biased = true, tags) |   val remoteTimeHist = newHistogram(RemoteTimeMs, biased = true, tags) | ||||||
|   // time a request is throttled |   // time a request is throttled, not part of the request processing time (throttling is done at the client level | ||||||
|  |   // for clients that support KIP-219 and by muting the channel for the rest) | ||||||
|   val throttleTimeHist = newHistogram(ThrottleTimeMs, biased = true, tags) |   val throttleTimeHist = newHistogram(ThrottleTimeMs, biased = true, tags) | ||||||
|   // time a response spent in a response queue |   // time a response spent in a response queue | ||||||
|   val responseQueueTimeHist = newHistogram(ResponseQueueTimeMs, biased = true, tags) |   val responseQueueTimeHist = newHistogram(ResponseQueueTimeMs, biased = true, tags) | ||||||
|  |  | ||||||
|  | @ -25,7 +25,6 @@ import java.util | ||||||
| import java.util.Optional | import java.util.Optional | ||||||
| import java.util.concurrent._ | import java.util.concurrent._ | ||||||
| import java.util.concurrent.atomic._ | import java.util.concurrent.atomic._ | ||||||
| import java.util.function.Supplier |  | ||||||
| 
 | 
 | ||||||
| import kafka.cluster.{BrokerEndPoint, EndPoint} | import kafka.cluster.{BrokerEndPoint, EndPoint} | ||||||
| import kafka.metrics.KafkaMetricsGroup | import kafka.metrics.KafkaMetricsGroup | ||||||
|  | @ -92,11 +91,12 @@ class SocketServer(val config: KafkaConfig, | ||||||
|   // data-plane |   // data-plane | ||||||
|   private val dataPlaneProcessors = new ConcurrentHashMap[Int, Processor]() |   private val dataPlaneProcessors = new ConcurrentHashMap[Int, Processor]() | ||||||
|   private[network] val dataPlaneAcceptors = new ConcurrentHashMap[EndPoint, Acceptor]() |   private[network] val dataPlaneAcceptors = new ConcurrentHashMap[EndPoint, Acceptor]() | ||||||
|   val dataPlaneRequestChannel = new RequestChannel(maxQueuedRequests, DataPlaneMetricPrefix) |   val dataPlaneRequestChannel = new RequestChannel(maxQueuedRequests, DataPlaneMetricPrefix, time) | ||||||
|   // control-plane |   // control-plane | ||||||
|   private var controlPlaneProcessorOpt : Option[Processor] = None |   private var controlPlaneProcessorOpt : Option[Processor] = None | ||||||
|   private[network] var controlPlaneAcceptorOpt : Option[Acceptor] = None |   private[network] var controlPlaneAcceptorOpt : Option[Acceptor] = None | ||||||
|   val controlPlaneRequestChannelOpt: Option[RequestChannel] = config.controlPlaneListenerName.map(_ => new RequestChannel(20, ControlPlaneMetricPrefix)) |   val controlPlaneRequestChannelOpt: Option[RequestChannel] = config.controlPlaneListenerName.map(_ => | ||||||
|  |     new RequestChannel(20, ControlPlaneMetricPrefix, time)) | ||||||
| 
 | 
 | ||||||
|   private var nextProcessorId = 0 |   private var nextProcessorId = 0 | ||||||
|   private var connectionQuotas: ConnectionQuotas = _ |   private var connectionQuotas: ConnectionQuotas = _ | ||||||
|  | @ -908,10 +908,6 @@ private[kafka] class Processor(val id: Int, | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   private def nowNanosSupplier = new Supplier[java.lang.Long] { |  | ||||||
|     override def get(): java.lang.Long = time.nanoseconds() |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   private def poll(): Unit = { |   private def poll(): Unit = { | ||||||
|     val pollTimeout = if (newConnections.isEmpty) 300 else 0 |     val pollTimeout = if (newConnections.isEmpty) 300 else 0 | ||||||
|     try selector.poll(pollTimeout) |     try selector.poll(pollTimeout) | ||||||
|  | @ -929,7 +925,8 @@ private[kafka] class Processor(val id: Int, | ||||||
|         openOrClosingChannel(receive.source) match { |         openOrClosingChannel(receive.source) match { | ||||||
|           case Some(channel) => |           case Some(channel) => | ||||||
|             val header = RequestHeader.parse(receive.payload) |             val header = RequestHeader.parse(receive.payload) | ||||||
|             if (header.apiKey == ApiKeys.SASL_HANDSHAKE && channel.maybeBeginServerReauthentication(receive, nowNanosSupplier)) |             if (header.apiKey == ApiKeys.SASL_HANDSHAKE && channel.maybeBeginServerReauthentication(receive, | ||||||
|  |               () => time.nanoseconds())) | ||||||
|               trace(s"Begin re-authentication: $channel") |               trace(s"Begin re-authentication: $channel") | ||||||
|             else { |             else { | ||||||
|               val nowNanos = time.nanoseconds() |               val nowNanos = time.nanoseconds() | ||||||
|  |  | ||||||
|  | @ -250,18 +250,16 @@ class ClientQuotaManager(private val config: ClientQuotaManagerConfig, | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   def recordAndGetThrottleTimeMs(session: Session, clientId: String, value: Double, timeMs: Long): Int = { |   def recordAndGetThrottleTimeMs(session: Session, clientId: String, value: Double, timeMs: Long): Int = { | ||||||
|     var throttleTimeMs = 0 |  | ||||||
|     val clientSensors = getOrCreateQuotaSensors(session, clientId) |     val clientSensors = getOrCreateQuotaSensors(session, clientId) | ||||||
|     try { |     try { | ||||||
|       clientSensors.quotaSensor.record(value, timeMs) |       clientSensors.quotaSensor.record(value, timeMs) | ||||||
|  |       0 | ||||||
|     } catch { |     } catch { | ||||||
|       case _: QuotaViolationException => |       case e: QuotaViolationException => | ||||||
|         // Compute the delay |         val throttleTimeMs = throttleTime(e.value, e.bound, windowSize(e.metric, timeMs)).toInt | ||||||
|         val clientMetric = metrics.metrics().get(clientRateMetricName(clientSensors.metricTags)) |         debug(s"Quota violated for sensor (${clientSensors.quotaSensor.name}). Delay time: ($throttleTimeMs)") | ||||||
|         throttleTimeMs = throttleTime(clientMetric).toInt |         throttleTimeMs | ||||||
|         debug("Quota violated for sensor (%s). Delay time: (%d)".format(clientSensors.quotaSensor.name(), throttleTimeMs)) |  | ||||||
|     } |     } | ||||||
|     throttleTimeMs |  | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /** "Unrecord" the given value that has already been recorded for the given user/client by recording a negative value |   /** "Unrecord" the given value that has already been recorded for the given user/client by recording a negative value | ||||||
|  | @ -337,16 +335,16 @@ class ClientQuotaManager(private val config: ClientQuotaManagerConfig, | ||||||
|    * we need to add a delay of X to W such that O * W / (W + X) = T. |    * we need to add a delay of X to W such that O * W / (W + X) = T. | ||||||
|    * Solving for X, we get X = (O - T)/T * W. |    * Solving for X, we get X = (O - T)/T * W. | ||||||
|    */ |    */ | ||||||
|   protected def throttleTime(clientMetric: KafkaMetric): Long = { |   protected def throttleTime(quotaValue: Double, quotaBound: Double, windowSize: Long): Long = { | ||||||
|     val config = clientMetric.config |     val difference = quotaValue - quotaBound | ||||||
|     val rateMetric: Rate = measurableAsRate(clientMetric.metricName(), clientMetric.measurable()) |  | ||||||
|     val quota = config.quota() |  | ||||||
|     val difference = clientMetric.metricValue.asInstanceOf[Double] - quota.bound |  | ||||||
|     // Use the precise window used by the rate calculation |     // Use the precise window used by the rate calculation | ||||||
|     val throttleTimeMs = difference / quota.bound * rateMetric.windowSize(config, time.milliseconds()) |     val throttleTimeMs = difference / quotaBound * windowSize | ||||||
|     throttleTimeMs.round |     Math.round(throttleTimeMs) | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   private def windowSize(metric: KafkaMetric, timeMs: Long): Long = | ||||||
|  |     measurableAsRate(metric.metricName, metric.measurable).windowSize(metric.config, timeMs) | ||||||
|  | 
 | ||||||
|   // Casting to Rate because we only use Rate in Quota computation |   // Casting to Rate because we only use Rate in Quota computation | ||||||
|   private def measurableAsRate(name: MetricName, measurable: Measurable): Rate = { |   private def measurableAsRate(name: MetricName, measurable: Measurable): Rate = { | ||||||
|     measurable match { |     measurable match { | ||||||
|  |  | ||||||
|  | @ -46,17 +46,12 @@ class ClientRequestQuotaManager(private val config: ClientQuotaManagerConfig, | ||||||
|     * @param request client request |     * @param request client request | ||||||
|     * @return Number of milliseconds to throttle in case of quota violation. Zero otherwise |     * @return Number of milliseconds to throttle in case of quota violation. Zero otherwise | ||||||
|     */ |     */ | ||||||
|   def maybeRecordAndGetThrottleTimeMs(request: RequestChannel.Request): Int = { |   def maybeRecordAndGetThrottleTimeMs(request: RequestChannel.Request, timeMs: Long): Int = { | ||||||
|     if (request.apiRemoteCompleteTimeNanos == -1) { |  | ||||||
|       // When this callback is triggered, the remote API call has completed |  | ||||||
|       request.apiRemoteCompleteTimeNanos = time.nanoseconds |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     if (quotasEnabled) { |     if (quotasEnabled) { | ||||||
|       request.recordNetworkThreadTimeCallback = Some(timeNanos => recordNoThrottle( |       request.recordNetworkThreadTimeCallback = Some(timeNanos => recordNoThrottle( | ||||||
|         getOrCreateQuotaSensors(request.session, request.header.clientId), nanosToPercentage(timeNanos))) |         getOrCreateQuotaSensors(request.session, request.header.clientId), nanosToPercentage(timeNanos))) | ||||||
|       recordAndGetThrottleTimeMs(request.session, request.header.clientId, |       recordAndGetThrottleTimeMs(request.session, request.header.clientId, | ||||||
|         nanosToPercentage(request.requestThreadTimeNanos), time.milliseconds()) |         nanosToPercentage(request.requestThreadTimeNanos), timeMs) | ||||||
|     } else { |     } else { | ||||||
|       0 |       0 | ||||||
|     } |     } | ||||||
|  | @ -69,8 +64,8 @@ class ClientRequestQuotaManager(private val config: ClientQuotaManagerConfig, | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   override protected def throttleTime(clientMetric: KafkaMetric): Long = { |   override protected def throttleTime(quotaValue: Double, quotaBound: Double, windowSize: Long): Long = { | ||||||
|     math.min(super.throttleTime(clientMetric), maxThrottleTimeMs) |     math.min(super.throttleTime(quotaValue, quotaBound, windowSize), maxThrottleTimeMs) | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   override protected def clientRateMetricName(quotaMetricTags: Map[String, String]): MetricName = { |   override protected def clientRateMetricName(quotaMetricTags: Map[String, String]): MetricName = { | ||||||
|  |  | ||||||
|  | @ -539,20 +539,21 @@ class KafkaApis(val requestChannel: RequestChannel, | ||||||
|         } |         } | ||||||
|       } |       } | ||||||
| 
 | 
 | ||||||
|       // When this callback is triggered, the remote API call has completed |  | ||||||
|       request.apiRemoteCompleteTimeNanos = time.nanoseconds |  | ||||||
| 
 |  | ||||||
|       // Record both bandwidth and request quota-specific values and throttle by muting the channel if any of the quotas |       // Record both bandwidth and request quota-specific values and throttle by muting the channel if any of the quotas | ||||||
|       // have been violated. If both quotas have been violated, use the max throttle time between the two quotas. Note |       // have been violated. If both quotas have been violated, use the max throttle time between the two quotas. Note | ||||||
|       // that the request quota is not enforced if acks == 0. |       // that the request quota is not enforced if acks == 0. | ||||||
|       val bandwidthThrottleTimeMs = quotas.produce.maybeRecordAndGetThrottleTimeMs(request, numBytesAppended, time.milliseconds) |       val timeMs = time.milliseconds() | ||||||
|       val requestThrottleTimeMs = if (produceRequest.acks == 0) 0 else quotas.request.maybeRecordAndGetThrottleTimeMs(request) |       val bandwidthThrottleTimeMs = quotas.produce.maybeRecordAndGetThrottleTimeMs(request, numBytesAppended, timeMs) | ||||||
|  |       val requestThrottleTimeMs = | ||||||
|  |         if (produceRequest.acks == 0) 0 | ||||||
|  |         else quotas.request.maybeRecordAndGetThrottleTimeMs(request, timeMs) | ||||||
|       val maxThrottleTimeMs = Math.max(bandwidthThrottleTimeMs, requestThrottleTimeMs) |       val maxThrottleTimeMs = Math.max(bandwidthThrottleTimeMs, requestThrottleTimeMs) | ||||||
|       if (maxThrottleTimeMs > 0) { |       if (maxThrottleTimeMs > 0) { | ||||||
|  |         request.apiThrottleTimeMs = maxThrottleTimeMs | ||||||
|         if (bandwidthThrottleTimeMs > requestThrottleTimeMs) { |         if (bandwidthThrottleTimeMs > requestThrottleTimeMs) { | ||||||
|           quotas.produce.throttle(request, bandwidthThrottleTimeMs, sendResponse) |           quotas.produce.throttle(request, bandwidthThrottleTimeMs, requestChannel.sendResponse) | ||||||
|         } else { |         } else { | ||||||
|           quotas.request.throttle(request, requestThrottleTimeMs, sendResponse) |           quotas.request.throttle(request, requestThrottleTimeMs, requestChannel.sendResponse) | ||||||
|         } |         } | ||||||
|       } |       } | ||||||
| 
 | 
 | ||||||
|  | @ -741,10 +742,6 @@ class KafkaApis(val requestChannel: RequestChannel, | ||||||
|       } |       } | ||||||
|       erroneous.foreach { case (tp, data) => partitions.put(tp, data) } |       erroneous.foreach { case (tp, data) => partitions.put(tp, data) } | ||||||
| 
 | 
 | ||||||
|       // When this callback is triggered, the remote API call has completed. |  | ||||||
|       // Record time before any byte-rate throttling. |  | ||||||
|       request.apiRemoteCompleteTimeNanos = time.nanoseconds |  | ||||||
| 
 |  | ||||||
|       var unconvertedFetchResponse: FetchResponse[Records] = null |       var unconvertedFetchResponse: FetchResponse[Records] = null | ||||||
| 
 | 
 | ||||||
|       def createResponse(throttleTimeMs: Int): FetchResponse[BaseRecords] = { |       def createResponse(throttleTimeMs: Int): FetchResponse[BaseRecords] = { | ||||||
|  | @ -794,19 +791,20 @@ class KafkaApis(val requestChannel: RequestChannel, | ||||||
|         // quotas have been violated. If both quotas have been violated, use the max throttle time between the two |         // quotas have been violated. If both quotas have been violated, use the max throttle time between the two | ||||||
|         // quotas. When throttled, we unrecord the recorded bandwidth quota value |         // quotas. When throttled, we unrecord the recorded bandwidth quota value | ||||||
|         val responseSize = fetchContext.getResponseSize(partitions, versionId) |         val responseSize = fetchContext.getResponseSize(partitions, versionId) | ||||||
|         val timeMs = time.milliseconds |         val timeMs = time.milliseconds() | ||||||
|         val requestThrottleTimeMs = quotas.request.maybeRecordAndGetThrottleTimeMs(request) |         val requestThrottleTimeMs = quotas.request.maybeRecordAndGetThrottleTimeMs(request, timeMs) | ||||||
|         val bandwidthThrottleTimeMs = quotas.fetch.maybeRecordAndGetThrottleTimeMs(request, responseSize, timeMs) |         val bandwidthThrottleTimeMs = quotas.fetch.maybeRecordAndGetThrottleTimeMs(request, responseSize, timeMs) | ||||||
| 
 | 
 | ||||||
|         val maxThrottleTimeMs = math.max(bandwidthThrottleTimeMs, requestThrottleTimeMs) |         val maxThrottleTimeMs = math.max(bandwidthThrottleTimeMs, requestThrottleTimeMs) | ||||||
|         if (maxThrottleTimeMs > 0) { |         if (maxThrottleTimeMs > 0) { | ||||||
|  |           request.apiThrottleTimeMs = maxThrottleTimeMs | ||||||
|           // Even if we need to throttle for request quota violation, we should "unrecord" the already recorded value |           // Even if we need to throttle for request quota violation, we should "unrecord" the already recorded value | ||||||
|           // from the fetch quota because we are going to return an empty response. |           // from the fetch quota because we are going to return an empty response. | ||||||
|           quotas.fetch.unrecordQuotaSensor(request, responseSize, timeMs) |           quotas.fetch.unrecordQuotaSensor(request, responseSize, timeMs) | ||||||
|           if (bandwidthThrottleTimeMs > requestThrottleTimeMs) { |           if (bandwidthThrottleTimeMs > requestThrottleTimeMs) { | ||||||
|             quotas.fetch.throttle(request, bandwidthThrottleTimeMs, sendResponse) |             quotas.fetch.throttle(request, bandwidthThrottleTimeMs, requestChannel.sendResponse) | ||||||
|           } else { |           } else { | ||||||
|             quotas.request.throttle(request, requestThrottleTimeMs, sendResponse) |             quotas.request.throttle(request, requestThrottleTimeMs, requestChannel.sendResponse) | ||||||
|           } |           } | ||||||
|           // If throttling is required, return an empty response. |           // If throttling is required, return an empty response. | ||||||
|           unconvertedFetchResponse = fetchContext.getThrottledResponse(maxThrottleTimeMs) |           unconvertedFetchResponse = fetchContext.getThrottledResponse(maxThrottleTimeMs) | ||||||
|  | @ -3012,17 +3010,23 @@ class KafkaApis(val requestChannel: RequestChannel, | ||||||
|   private def sendResponseMaybeThrottle(request: RequestChannel.Request, |   private def sendResponseMaybeThrottle(request: RequestChannel.Request, | ||||||
|                                         createResponse: Int => AbstractResponse, |                                         createResponse: Int => AbstractResponse, | ||||||
|                                         onComplete: Option[Send => Unit] = None): Unit = { |                                         onComplete: Option[Send => Unit] = None): Unit = { | ||||||
|     val throttleTimeMs = quotas.request.maybeRecordAndGetThrottleTimeMs(request) |     val throttleTimeMs = maybeRecordAndGetThrottleTimeMs(request) | ||||||
|     quotas.request.throttle(request, throttleTimeMs, sendResponse) |     quotas.request.throttle(request, throttleTimeMs, requestChannel.sendResponse) | ||||||
|     sendResponse(request, Some(createResponse(throttleTimeMs)), onComplete) |     sendResponse(request, Some(createResponse(throttleTimeMs)), onComplete) | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   private def sendErrorResponseMaybeThrottle(request: RequestChannel.Request, error: Throwable): Unit = { |   private def sendErrorResponseMaybeThrottle(request: RequestChannel.Request, error: Throwable): Unit = { | ||||||
|     val throttleTimeMs = quotas.request.maybeRecordAndGetThrottleTimeMs(request) |     val throttleTimeMs = maybeRecordAndGetThrottleTimeMs(request) | ||||||
|     quotas.request.throttle(request, throttleTimeMs, sendResponse) |     quotas.request.throttle(request, throttleTimeMs, requestChannel.sendResponse) | ||||||
|     sendErrorOrCloseConnection(request, error, throttleTimeMs) |     sendErrorOrCloseConnection(request, error, throttleTimeMs) | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   private def maybeRecordAndGetThrottleTimeMs(request: RequestChannel.Request): Int = { | ||||||
|  |     val throttleTimeMs = quotas.request.maybeRecordAndGetThrottleTimeMs(request, time.milliseconds()) | ||||||
|  |     request.apiThrottleTimeMs = throttleTimeMs | ||||||
|  |     throttleTimeMs | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   private def sendResponseExemptThrottle(request: RequestChannel.Request, |   private def sendResponseExemptThrottle(request: RequestChannel.Request, | ||||||
|                                          response: AbstractResponse, |                                          response: AbstractResponse, | ||||||
|                                          onComplete: Option[Send => Unit] = None): Unit = { |                                          onComplete: Option[Send => Unit] = None): Unit = { | ||||||
|  | @ -3072,10 +3076,7 @@ class KafkaApis(val requestChannel: RequestChannel, | ||||||
|       case None => |       case None => | ||||||
|         new RequestChannel.NoOpResponse(request) |         new RequestChannel.NoOpResponse(request) | ||||||
|     } |     } | ||||||
|     sendResponse(response) |  | ||||||
|   } |  | ||||||
| 
 | 
 | ||||||
|   private def sendResponse(response: RequestChannel.Response): Unit = { |  | ||||||
|     requestChannel.sendResponse(response) |     requestChannel.sendResponse(response) | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -107,7 +107,8 @@ class ReplicationQuotaManager(val config: ReplicationQuotaManagerConfig, | ||||||
|       sensor().checkQuotas() |       sensor().checkQuotas() | ||||||
|     } catch { |     } catch { | ||||||
|       case qve: QuotaViolationException => |       case qve: QuotaViolationException => | ||||||
|         trace("%s: Quota violated for sensor (%s), metric: (%s), metric-value: (%f), bound: (%f)".format(replicationType, sensor().name(), qve.metricName, qve.value, qve.bound)) |         trace(s"$replicationType: Quota violated for sensor (${sensor().name}), metric: (${qve.metric.metricName}), " + | ||||||
|  |           s"metric-value: (${qve.value}), bound: (${qve.bound})") | ||||||
|         return true |         return true | ||||||
|     } |     } | ||||||
|     false |     false | ||||||
|  |  | ||||||
|  | @ -33,9 +33,11 @@ import org.apache.kafka.common.utils.Time | ||||||
|   * @param throttleTimeMs Delay associated with this request |   * @param throttleTimeMs Delay associated with this request | ||||||
|   * @param channelThrottlingCallback Callback for channel throttling |   * @param channelThrottlingCallback Callback for channel throttling | ||||||
|   */ |   */ | ||||||
| class ThrottledChannel(val request: RequestChannel.Request, val time: Time, val throttleTimeMs: Int, channelThrottlingCallback: Response => Unit) | class ThrottledChannel(val request: RequestChannel.Request, val time: Time, val throttleTimeMs: Int, | ||||||
|  |                        channelThrottlingCallback: Response => Unit) | ||||||
|   extends Delayed with Logging { |   extends Delayed with Logging { | ||||||
|   var endTime = time.milliseconds + throttleTimeMs | 
 | ||||||
|  |   private val endTimeNanos = time.nanoseconds() + TimeUnit.MILLISECONDS.toNanos(throttleTimeMs) | ||||||
| 
 | 
 | ||||||
|   // Notify the socket server that throttling has started for this channel. |   // Notify the socket server that throttling has started for this channel. | ||||||
|   channelThrottlingCallback(new RequestChannel.StartThrottlingResponse(request)) |   channelThrottlingCallback(new RequestChannel.StartThrottlingResponse(request)) | ||||||
|  | @ -47,13 +49,11 @@ class ThrottledChannel(val request: RequestChannel.Request, val time: Time, val | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   override def getDelay(unit: TimeUnit): Long = { |   override def getDelay(unit: TimeUnit): Long = { | ||||||
|     unit.convert(endTime - time.milliseconds, TimeUnit.MILLISECONDS) |     unit.convert(endTimeNanos - time.nanoseconds(), TimeUnit.NANOSECONDS) | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   override def compareTo(d: Delayed): Int = { |   override def compareTo(d: Delayed): Int = { | ||||||
|     val other = d.asInstanceOf[ThrottledChannel] |     val other = d.asInstanceOf[ThrottledChannel] | ||||||
|     if (this.endTime < other.endTime) -1 |     java.lang.Long.compare(this.endTimeNanos, other.endTimeNanos) | ||||||
|     else if (this.endTime > other.endTime) 1 |  | ||||||
|     else 0 |  | ||||||
|   } |   } | ||||||
| } | } | ||||||
|  | @ -122,12 +122,8 @@ private[timer] class TimerTaskList(taskCounter: AtomicInteger) extends Delayed { | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   def compareTo(d: Delayed): Int = { |   def compareTo(d: Delayed): Int = { | ||||||
| 
 |  | ||||||
|     val other = d.asInstanceOf[TimerTaskList] |     val other = d.asInstanceOf[TimerTaskList] | ||||||
| 
 |     java.lang.Long.compare(getExpiration, other.getExpiration) | ||||||
|     if(getExpiration < other.getExpiration) -1 |  | ||||||
|     else if(getExpiration > other.getExpiration) 1 |  | ||||||
|     else 0 |  | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
| } | } | ||||||
|  | @ -159,7 +155,7 @@ private[timer] class TimerTaskEntry(val timerTask: TimerTask, val expirationMs: | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   override def compare(that: TimerTaskEntry): Int = { |   override def compare(that: TimerTaskEntry): Int = { | ||||||
|     this.expirationMs compare that.expirationMs |     java.lang.Long.compare(expirationMs, that.expirationMs) | ||||||
|   } |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -18,7 +18,9 @@ import java.time.Duration | ||||||
| import java.util.concurrent.TimeUnit | import java.util.concurrent.TimeUnit | ||||||
| import java.util.{Collections, HashMap, Properties} | import java.util.{Collections, HashMap, Properties} | ||||||
| 
 | 
 | ||||||
|  | import com.yammer.metrics.core.{Histogram, Meter} | ||||||
| import kafka.api.QuotaTestClients._ | import kafka.api.QuotaTestClients._ | ||||||
|  | import kafka.metrics.KafkaYammerMetrics | ||||||
| import kafka.server.{ClientQuotaManager, ClientQuotaManagerConfig, DynamicConfig, KafkaConfig, KafkaServer, QuotaType} | import kafka.server.{ClientQuotaManager, ClientQuotaManagerConfig, DynamicConfig, KafkaConfig, KafkaServer, QuotaType} | ||||||
| import kafka.utils.TestUtils | import kafka.utils.TestUtils | ||||||
| import org.apache.kafka.clients.consumer.{ConsumerConfig, KafkaConsumer} | import org.apache.kafka.clients.consumer.{ConsumerConfig, KafkaConsumer} | ||||||
|  | @ -26,10 +28,13 @@ import org.apache.kafka.clients.producer._ | ||||||
| import org.apache.kafka.clients.producer.internals.ErrorLoggingCallback | import org.apache.kafka.clients.producer.internals.ErrorLoggingCallback | ||||||
| import org.apache.kafka.common.{Metric, MetricName, TopicPartition} | import org.apache.kafka.common.{Metric, MetricName, TopicPartition} | ||||||
| import org.apache.kafka.common.metrics.{KafkaMetric, Quota} | import org.apache.kafka.common.metrics.{KafkaMetric, Quota} | ||||||
|  | import org.apache.kafka.common.protocol.ApiKeys | ||||||
| import org.apache.kafka.common.security.auth.KafkaPrincipal | import org.apache.kafka.common.security.auth.KafkaPrincipal | ||||||
| import org.junit.Assert._ | import org.junit.Assert._ | ||||||
| import org.junit.{Before, Test} | import org.junit.{Before, Test} | ||||||
|  | import org.scalatest.Assertions.fail | ||||||
| 
 | 
 | ||||||
|  | import scala.collection.Map | ||||||
| import scala.jdk.CollectionConverters._ | import scala.jdk.CollectionConverters._ | ||||||
| 
 | 
 | ||||||
| abstract class BaseQuotaTest extends IntegrationTestHarness { | abstract class BaseQuotaTest extends IntegrationTestHarness { | ||||||
|  | @ -186,15 +191,11 @@ abstract class QuotaTestClients(topic: String, | ||||||
|                                 val producer: KafkaProducer[Array[Byte], Array[Byte]], |                                 val producer: KafkaProducer[Array[Byte], Array[Byte]], | ||||||
|                                 val consumer: KafkaConsumer[Array[Byte], Array[Byte]]) { |                                 val consumer: KafkaConsumer[Array[Byte], Array[Byte]]) { | ||||||
| 
 | 
 | ||||||
|   def userPrincipal: KafkaPrincipal |  | ||||||
|   def overrideQuotas(producerQuota: Long, consumerQuota: Long, requestQuota: Double): Unit |   def overrideQuotas(producerQuota: Long, consumerQuota: Long, requestQuota: Double): Unit | ||||||
|   def removeQuotaOverrides(): Unit |   def removeQuotaOverrides(): Unit | ||||||
| 
 | 
 | ||||||
|   def quotaMetricTags(clientId: String): Map[String, String] |   protected def userPrincipal: KafkaPrincipal | ||||||
| 
 |   protected def quotaMetricTags(clientId: String): Map[String, String] | ||||||
|   def quota(quotaManager: ClientQuotaManager, userPrincipal: KafkaPrincipal, clientId: String): Quota = { |  | ||||||
|     quotaManager.quota(userPrincipal, clientId) |  | ||||||
|   } |  | ||||||
| 
 | 
 | ||||||
|   def produceUntilThrottled(maxRecords: Int, waitForRequestCompletion: Boolean = true): Int = { |   def produceUntilThrottled(maxRecords: Int, waitForRequestCompletion: Boolean = true): Int = { | ||||||
|     var numProduced = 0 |     var numProduced = 0 | ||||||
|  | @ -235,19 +236,38 @@ abstract class QuotaTestClients(topic: String, | ||||||
|     numConsumed |     numConsumed | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   def verifyProduceThrottle(expectThrottle: Boolean, verifyClientMetric: Boolean = true): Unit = { |   private def quota(quotaManager: ClientQuotaManager, userPrincipal: KafkaPrincipal, clientId: String): Quota = { | ||||||
|  |     quotaManager.quota(userPrincipal, clientId) | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   private def verifyThrottleTimeRequestChannelMetric(apiKey: ApiKeys, metricNameSuffix: String, | ||||||
|  |                                                      clientId: String, expectThrottle: Boolean): Unit = { | ||||||
|  |     val throttleTimeMs = brokerRequestMetricsThrottleTimeMs(apiKey, metricNameSuffix) | ||||||
|  |     if (expectThrottle) | ||||||
|  |       assertTrue(s"Client with id=$clientId should have been throttled, $throttleTimeMs", throttleTimeMs > 0) | ||||||
|  |     else | ||||||
|  |       assertEquals(s"Client with id=$clientId should not have been throttled", 0.0, throttleTimeMs, 0.0) | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   def verifyProduceThrottle(expectThrottle: Boolean, verifyClientMetric: Boolean = true, | ||||||
|  |                             verifyRequestChannelMetric: Boolean = true): Unit = { | ||||||
|     verifyThrottleTimeMetric(QuotaType.Produce, producerClientId, expectThrottle) |     verifyThrottleTimeMetric(QuotaType.Produce, producerClientId, expectThrottle) | ||||||
|  |     if (verifyRequestChannelMetric) | ||||||
|  |       verifyThrottleTimeRequestChannelMetric(ApiKeys.PRODUCE, "", producerClientId, expectThrottle) | ||||||
|     if (verifyClientMetric) |     if (verifyClientMetric) | ||||||
|       verifyProducerClientThrottleTimeMetric(expectThrottle) |       verifyProducerClientThrottleTimeMetric(expectThrottle) | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   def verifyConsumeThrottle(expectThrottle: Boolean, verifyClientMetric: Boolean = true): Unit = { |   def verifyConsumeThrottle(expectThrottle: Boolean, verifyClientMetric: Boolean = true, | ||||||
|  |                             verifyRequestChannelMetric: Boolean = true): Unit = { | ||||||
|     verifyThrottleTimeMetric(QuotaType.Fetch, consumerClientId, expectThrottle) |     verifyThrottleTimeMetric(QuotaType.Fetch, consumerClientId, expectThrottle) | ||||||
|  |     if (verifyRequestChannelMetric) | ||||||
|  |       verifyThrottleTimeRequestChannelMetric(ApiKeys.FETCH, "Consumer", consumerClientId, expectThrottle) | ||||||
|     if (verifyClientMetric) |     if (verifyClientMetric) | ||||||
|       verifyConsumerClientThrottleTimeMetric(expectThrottle) |       verifyConsumerClientThrottleTimeMetric(expectThrottle) | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   def verifyThrottleTimeMetric(quotaType: QuotaType, clientId: String, expectThrottle: Boolean): Unit = { |   private def verifyThrottleTimeMetric(quotaType: QuotaType, clientId: String, expectThrottle: Boolean): Unit = { | ||||||
|     val throttleMetricValue = metricValue(throttleMetric(quotaType, clientId)) |     val throttleMetricValue = metricValue(throttleMetric(quotaType, clientId)) | ||||||
|     if (expectThrottle) { |     if (expectThrottle) { | ||||||
|       assertTrue(s"Client with id=$clientId should have been throttled", throttleMetricValue > 0) |       assertTrue(s"Client with id=$clientId should have been throttled", throttleMetricValue > 0) | ||||||
|  | @ -256,7 +276,7 @@ abstract class QuotaTestClients(topic: String, | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   def throttleMetricName(quotaType: QuotaType, clientId: String): MetricName = { |   private def throttleMetricName(quotaType: QuotaType, clientId: String): MetricName = { | ||||||
|     leaderNode.metrics.metricName("throttle-time", |     leaderNode.metrics.metricName("throttle-time", | ||||||
|       quotaType.toString, |       quotaType.toString, | ||||||
|       quotaMetricTags(clientId).asJava) |       quotaMetricTags(clientId).asJava) | ||||||
|  | @ -266,12 +286,28 @@ abstract class QuotaTestClients(topic: String, | ||||||
|     leaderNode.metrics.metrics.get(throttleMetricName(quotaType, clientId)) |     leaderNode.metrics.metrics.get(throttleMetricName(quotaType, clientId)) | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   private def brokerRequestMetricsThrottleTimeMs(apiKey: ApiKeys, metricNameSuffix: String): Double = { | ||||||
|  |     def yammerMetricValue(name: String): Double = { | ||||||
|  |       val allMetrics = KafkaYammerMetrics.defaultRegistry.allMetrics.asScala | ||||||
|  |       val (_, metric) = allMetrics.find { case (metricName, _) => | ||||||
|  |         metricName.getMBeanName.startsWith(name) | ||||||
|  |       }.getOrElse(fail(s"Unable to find broker metric $name: allMetrics: ${allMetrics.keySet.map(_.getMBeanName)}")) | ||||||
|  |       metric match { | ||||||
|  |         case m: Meter => m.count.toDouble | ||||||
|  |         case m: Histogram => m.max | ||||||
|  |         case m => fail(s"Unexpected broker metric of class ${m.getClass}") | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     yammerMetricValue(s"kafka.network:type=RequestMetrics,name=ThrottleTimeMs,request=${apiKey.name}$metricNameSuffix") | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   def exemptRequestMetric: KafkaMetric = { |   def exemptRequestMetric: KafkaMetric = { | ||||||
|     val metricName = leaderNode.metrics.metricName("exempt-request-time", QuotaType.Request.toString, "") |     val metricName = leaderNode.metrics.metricName("exempt-request-time", QuotaType.Request.toString, "") | ||||||
|     leaderNode.metrics.metrics.get(metricName) |     leaderNode.metrics.metrics.get(metricName) | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   def verifyProducerClientThrottleTimeMetric(expectThrottle: Boolean): Unit = { |   private def verifyProducerClientThrottleTimeMetric(expectThrottle: Boolean): Unit = { | ||||||
|     val tags = new HashMap[String, String] |     val tags = new HashMap[String, String] | ||||||
|     tags.put("client-id", producerClientId) |     tags.put("client-id", producerClientId) | ||||||
|     val avgMetric = producer.metrics.get(new MetricName("produce-throttle-time-avg", "producer-metrics", "", tags)) |     val avgMetric = producer.metrics.get(new MetricName("produce-throttle-time-avg", "producer-metrics", "", tags)) | ||||||
|  |  | ||||||
|  | @ -264,11 +264,16 @@ class CustomQuotaCallbackTest extends IntegrationTestHarness with SaslSetup { | ||||||
|     def produceConsume(expectProduceThrottle: Boolean, expectConsumeThrottle: Boolean): Unit = { |     def produceConsume(expectProduceThrottle: Boolean, expectConsumeThrottle: Boolean): Unit = { | ||||||
|       val numRecords = 1000 |       val numRecords = 1000 | ||||||
|       val produced = produceUntilThrottled(numRecords, waitForRequestCompletion = false) |       val produced = produceUntilThrottled(numRecords, waitForRequestCompletion = false) | ||||||
|       verifyProduceThrottle(expectProduceThrottle, verifyClientMetric = false) |       // don't verify request channel metrics as it's difficult to write non flaky assertions | ||||||
|  |       // given the specifics of this test (throttle metric removal followed by produce/consume | ||||||
|  |       // until throttled) | ||||||
|  |       verifyProduceThrottle(expectProduceThrottle, verifyClientMetric = false, | ||||||
|  |         verifyRequestChannelMetric = false) | ||||||
|       // make sure there are enough records on the topic to test consumer throttling |       // make sure there are enough records on the topic to test consumer throttling | ||||||
|       produceWithoutThrottle(topic, numRecords - produced) |       produceWithoutThrottle(topic, numRecords - produced) | ||||||
|       consumeUntilThrottled(numRecords, waitForRequestCompletion = false) |       consumeUntilThrottled(numRecords, waitForRequestCompletion = false) | ||||||
|       verifyConsumeThrottle(expectConsumeThrottle, verifyClientMetric = false) |       verifyConsumeThrottle(expectConsumeThrottle, verifyClientMetric = false, | ||||||
|  |         verifyRequestChannelMetric = false) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     def removeThrottleMetrics(): Unit = { |     def removeThrottleMetrics(): Unit = { | ||||||
|  |  | ||||||
|  | @ -1841,8 +1841,8 @@ class KafkaApisTest { | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   private def expectNoThrottling(): Capture[RequestChannel.Response] = { |   private def expectNoThrottling(): Capture[RequestChannel.Response] = { | ||||||
|     EasyMock.expect(clientRequestQuotaManager.maybeRecordAndGetThrottleTimeMs(EasyMock.anyObject[RequestChannel.Request]())) |     EasyMock.expect(clientRequestQuotaManager.maybeRecordAndGetThrottleTimeMs(EasyMock.anyObject[RequestChannel.Request](), | ||||||
|       .andReturn(0) |       EasyMock.anyObject[Long])).andReturn(0) | ||||||
|     EasyMock.expect(clientRequestQuotaManager.throttle(EasyMock.anyObject[RequestChannel.Request](), EasyMock.eq(0), |     EasyMock.expect(clientRequestQuotaManager.throttle(EasyMock.anyObject[RequestChannel.Request](), EasyMock.eq(0), | ||||||
|       EasyMock.anyObject[RequestChannel.Response => Unit]())) |       EasyMock.anyObject[RequestChannel.Response => Unit]())) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue