diff --git a/tests/kafkatest/services/verifiable_producer.py b/tests/kafkatest/services/verifiable_producer.py
index f2ea421de33..0096a34875e 100644
--- a/tests/kafkatest/services/verifiable_producer.py
+++ b/tests/kafkatest/services/verifiable_producer.py
@@ -52,12 +52,8 @@ class VerifiableProducer(BackgroundThreadService):
             num_nodes = 1
             * is_int_with_prefix recommended if num_nodes > 1, because otherwise each producer
               will produce exactly same messages, and validation may miss missing messages.
-        :param compression_types: If None, all producers will not use compression; or a list of one or
-            more compression types (including "none"). Each producer will pick a compression type
-            from the list in round-robin fashion. Example: compression_types = ["none", "snappy"] and
-            num_nodes = 3, then producer 1 and 2 will not use compression, and producer 3 will use
-            compression type = snappy. If in this example, num_nodes is 1, then first (and only)
-            producer will not use compression.
+        :param compression_types: If None, all producers will not use compression; or a list of
+            compression types, one per producer (could be "none").
         """
         super(VerifiableProducer, self).__init__(context, num_nodes)
 
@@ -67,30 +63,36 @@ class VerifiableProducer(BackgroundThreadService):
         self.throughput = throughput
         self.message_validator = message_validator
         self.compression_types = compression_types
+        if self.compression_types is not None:
+            assert len(self.compression_types) == num_nodes, "Specify one compression type per node"
+
+        self.security_config = self.kafka.security_config.client_config()
 
         for node in self.nodes:
             node.version = version
         self.acked_values = []
         self.not_acked_values = []
         self.produced_count = {}
-        self.prop_file = ""
+
+    def prop_file(self, node):
+        idx = self.idx(node)
+        prop_file = str(self.security_config)
+        if self.compression_types is not None:
+            compression_index = idx - 1
+            self.logger.info("VerifiableProducer (index = %d) will use compression type = %s", idx,
+                             self.compression_types[compression_index])
+            prop_file += "\ncompression.type=%s\n" % self.compression_types[compression_index]
+        return prop_file
 
     def _worker(self, idx, node):
        node.account.ssh("mkdir -p %s" % VerifiableProducer.PERSISTENT_ROOT, allow_fail=False)
 
         # Create and upload log properties
-        self.security_config = self.kafka.security_config.client_config(self.prop_file)
-        producer_prop_file = str(self.security_config)
         log_config = self.render('tools_log4j.properties', log_file=VerifiableProducer.LOG_FILE)
         node.account.create_file(VerifiableProducer.LOG4J_CONFIG, log_config)
 
         # Create and upload config file
-        if self.compression_types is not None:
-            compression_index = (idx - 1) % len(self.compression_types)
-            self.logger.info("VerifiableProducer (index = %d) will use compression type = %s", idx,
-                             self.compression_types[compression_index])
-            producer_prop_file += "\ncompression.type=%s\n" % self.compression_types[compression_index]
-
+        producer_prop_file = self.prop_file(node)
         self.logger.info("verifiable_producer.properties:")
         self.logger.info(producer_prop_file)
         node.account.create_file(VerifiableProducer.CONFIG_FILE, producer_prop_file)
@@ -197,7 +199,7 @@ class VerifiableProducer(BackgroundThreadService):
 
     def each_produced_at_least(self, count):
         with self.lock:
-            for idx in range(1, self.num_nodes):
+            for idx in range(1, self.num_nodes + 1):
                 if self.produced_count.get(idx) is None or self.produced_count[idx] < count:
                     return False
             return True
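
For review context, a minimal usage sketch of the contract this change introduces: compression_types now maps one codec to each producer node instead of being cycled round-robin, and the new assert in __init__ rejects mismatched lengths up front. The keyword arguments and the test_context/kafka fixtures below are assumptions for illustration, not part of the diff.

    # Hypothetical usage under the new contract: exactly one compression type
    # per node. A shorter list such as ["none", "snappy"] with num_nodes=3,
    # which the old round-robin scheme accepted, now trips the assert
    # ("Specify one compression type per node").
    producer = VerifiableProducer(test_context, num_nodes=3, kafka=kafka,
                                  topic="test_topic",
                                  compression_types=["none", "gzip", "snappy"])

The each_produced_at_least fix is related but independent: producer indices are 1-based, so range(1, self.num_nodes) never visited the last node (and was empty for num_nodes=1, making the check vacuously true); range(1, self.num_nodes + 1) covers every producer.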