KAFKA-15045: (KIP-924 pt. 4) Generify rack graph solving utilities (#15956)

The graph-solving utilities are currently hardcoded to work with ClientState, but they don't actually depend on anything in that class.

This change makes the MinTrafficGraphConstructor and BalanceSubtopologyGraphConstructor generic over the client state type, so that they can be reused with KafkaStreamsStates instead.

Reviewers: Anna Sophie Blee-Goldman <ableegoldman@apache.org>, Almog Gavra <almog@responsive.dev>
This commit is contained in:
Antoine Pourchet 2024-05-16 12:37:59 -06:00 committed by GitHub
parent 056d232f4e
commit fafa3c76dc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 40 additions and 40 deletions

View File

@ -35,7 +35,7 @@ import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
import org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor.CostFunction;
public class BalanceSubtopologyGraphConstructor implements RackAwareGraphConstructor {
public class BalanceSubtopologyGraphConstructor<T> implements RackAwareGraphConstructor<T> {
private final Map<Subtopology, Set<TaskId>> tasksForTopicGroup;
@ -71,10 +71,10 @@ public class BalanceSubtopologyGraphConstructor implements RackAwareGraphConstru
public Graph<Integer> constructTaskGraph(
final List<UUID> clientList,
final List<TaskId> taskIdList,
final Map<UUID, ClientState> clientStates,
final Map<UUID, T> clientStates,
final Map<TaskId, UUID> taskClientMap,
final Map<UUID, Integer> originalAssignedTaskNumber,
final BiPredicate<ClientState, TaskId> hasAssignedTask,
final BiPredicate<T, TaskId> hasAssignedTask,
final CostFunction costFunction,
final int trafficCost,
final int nonOverlapCost,
@ -86,7 +86,7 @@ public class BalanceSubtopologyGraphConstructor implements RackAwareGraphConstru
final Graph<Integer> graph = new Graph<>();
for (final TaskId taskId : taskIdList) {
for (final Entry<UUID, ClientState> clientState : clientStates.entrySet()) {
for (final Entry<UUID, T> clientState : clientStates.entrySet()) {
if (hasAssignedTask.test(clientState.getValue(), taskId)) {
originalAssignedTaskNumber.merge(clientState.getKey(), 1, Integer::sum);
}
@ -122,12 +122,12 @@ public class BalanceSubtopologyGraphConstructor implements RackAwareGraphConstru
final Graph<Integer> graph,
final List<UUID> clientList,
final List<TaskId> taskIdList,
final Map<UUID, ClientState> clientStates,
final Map<UUID, T> clientStates,
final Map<UUID, Integer> originalAssignedTaskNumber,
final Map<TaskId, UUID> taskClientMap,
final BiConsumer<ClientState, TaskId> assignTask,
final BiConsumer<ClientState, TaskId> unAssignTask,
final BiPredicate<ClientState, TaskId> hasAssignedTask
final BiConsumer<T, TaskId> assignTask,
final BiConsumer<T, TaskId> unAssignTask,
final BiPredicate<T, TaskId> hasAssignedTask
) {
final SortedMap<Subtopology, Set<TaskId>> sortedTasksForTopicGroup = new TreeMap<>(tasksForTopicGroup);
final Set<TaskId> taskIdSet = new HashSet<>(taskIdList);
@ -170,10 +170,10 @@ public class BalanceSubtopologyGraphConstructor implements RackAwareGraphConstru
final Graph<Integer> graph,
final List<TaskId> taskIdList,
final List<UUID> clientList,
final Map<UUID, ClientState> clientStates,
final Map<UUID, T> clientStates,
final Map<TaskId, UUID> taskClientMap,
final Map<UUID, Integer> originalAssignedTaskNumber,
final BiPredicate<ClientState, TaskId> hasAssignedTask,
final BiPredicate<T, TaskId> hasAssignedTask,
final CostFunction costFunction,
final int trafficCost,
final int nonOverlapCost,

View File

@ -28,7 +28,7 @@ import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
import org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor.CostFunction;
public class MinTrafficGraphConstructor implements RackAwareGraphConstructor {
public class MinTrafficGraphConstructor<T> implements RackAwareGraphConstructor<T> {
@Override
public int getSinkNodeID(
@ -53,10 +53,10 @@ public class MinTrafficGraphConstructor implements RackAwareGraphConstructor {
public Graph<Integer> constructTaskGraph(
final List<UUID> clientList,
final List<TaskId> taskIdList,
final Map<UUID, ClientState> clientStates,
final Map<UUID, T> clientStates,
final Map<TaskId, UUID> taskClientMap,
final Map<UUID, Integer> originalAssignedTaskNumber,
final BiPredicate<ClientState, TaskId> hasAssignedTask,
final BiPredicate<T, TaskId> hasAssignedTask,
final CostFunction costFunction,
final int trafficCost,
final int nonOverlapCost,
@ -66,7 +66,7 @@ public class MinTrafficGraphConstructor implements RackAwareGraphConstructor {
final Graph<Integer> graph = new Graph<>();
for (final TaskId taskId : taskIdList) {
for (final Entry<UUID, ClientState> clientState : clientStates.entrySet()) {
for (final Entry<UUID, T> clientState : clientStates.entrySet()) {
if (hasAssignedTask.test(clientState.getValue(), taskId)) {
originalAssignedTaskNumber.merge(clientState.getKey(), 1, Integer::sum);
}
@ -122,12 +122,12 @@ public class MinTrafficGraphConstructor implements RackAwareGraphConstructor {
final Graph<Integer> graph,
final List<UUID> clientList,
final List<TaskId> taskIdList,
final Map<UUID, ClientState> clientStates,
final Map<UUID, T> clientStates,
final Map<UUID, Integer> originalAssignedTaskNumber,
final Map<TaskId, UUID> taskClientMap,
final BiConsumer<ClientState, TaskId> assignTask,
final BiConsumer<ClientState, TaskId> unAssignTask,
final BiPredicate<ClientState, TaskId> hasAssignedTask
final BiConsumer<T, TaskId> assignTask,
final BiConsumer<T, TaskId> unAssignTask,
final BiPredicate<T, TaskId> hasAssignedTask
) {
int tasksAssigned = 0;
boolean taskMoved = false;

View File

@ -33,7 +33,7 @@ import org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssi
/**
* Construct graph for rack aware task assignor
*/
public interface RackAwareGraphConstructor {
public interface RackAwareGraphConstructor<T> {
int SOURCE_ID = -1;
int getSinkNodeID(final List<TaskId> taskIdList, final List<UUID> clientList, final Map<Subtopology, Set<TaskId>> tasksForTopicGroup);
@ -45,10 +45,10 @@ public interface RackAwareGraphConstructor {
Graph<Integer> constructTaskGraph(
final List<UUID> clientList,
final List<TaskId> taskIdList,
final Map<UUID, ClientState> clientStates,
final Map<UUID, T> clientStates,
final Map<TaskId, UUID> taskClientMap,
final Map<UUID, Integer> originalAssignedTaskNumber,
final BiPredicate<ClientState, TaskId> hasAssignedTask,
final BiPredicate<T, TaskId> hasAssignedTask,
final CostFunction costFunction,
final int trafficCost,
final int nonOverlapCost,
@ -59,24 +59,24 @@ public interface RackAwareGraphConstructor {
final Graph<Integer> graph,
final List<UUID> clientList,
final List<TaskId> taskIdList,
final Map<UUID, ClientState> clientStates,
final Map<UUID, T> clientStates,
final Map<UUID, Integer> originalAssignedTaskNumber,
final Map<TaskId, UUID> taskClientMap,
final BiConsumer<ClientState, TaskId> assignTask,
final BiConsumer<ClientState, TaskId> unAssignTask,
final BiPredicate<ClientState, TaskId> hasAssignedTask);
final BiConsumer<T, TaskId> assignTask,
final BiConsumer<T, TaskId> unAssignTask,
final BiPredicate<T, TaskId> hasAssignedTask);
default KeyValue<Boolean, Integer> assignTaskToClient(
final Graph<Integer> graph,
final TaskId taskId,
final int taskNodeId,
final int topicGroupIndex,
final Map<UUID, ClientState> clientStates,
final Map<UUID, T> clientStates,
final List<UUID> clientList,
final List<TaskId> taskIdList,
final Map<TaskId, UUID> taskClientMap,
final BiConsumer<ClientState, TaskId> assignTask,
final BiConsumer<ClientState, TaskId> unAssignTask
final BiConsumer<T, TaskId> assignTask,
final BiConsumer<T, TaskId> unAssignTask
) {
int tasksAssigned = 0;
boolean taskMoved = false;
@ -104,9 +104,9 @@ public interface RackAwareGraphConstructor {
default void validateAssignedTask(
final List<TaskId> taskIdList,
final int tasksAssigned,
final Map<UUID, ClientState> clientStates,
final Map<UUID, T> clientStates,
final Map<UUID, Integer> originalAssignedTaskNumber,
final BiPredicate<ClientState, TaskId> hasAssignedTask
final BiPredicate<T, TaskId> hasAssignedTask
) {
// Validate task assigned
if (tasksAssigned != taskIdList.size()) {
@ -117,7 +117,7 @@ public interface RackAwareGraphConstructor {
// Validate original assigned task number matches
final Map<UUID, Integer> assignedTaskNumber = new HashMap<>();
for (final TaskId taskId : taskIdList) {
for (final Entry<UUID, ClientState> clientState : clientStates.entrySet()) {
for (final Entry<UUID, T> clientState : clientStates.entrySet()) {
if (hasAssignedTask.test(clientState.getValue(), taskId)) {
assignedTaskNumber.merge(clientState.getKey(), 1, Integer::sum);
}

View File

@ -25,12 +25,12 @@ import org.apache.kafka.streams.processor.internals.assignment.AssignorConfigura
public class RackAwareGraphConstructorFactory {
static RackAwareGraphConstructor create(final AssignmentConfigs assignmentConfigs, final Map<Subtopology, Set<TaskId>> tasksForTopicGroup) {
static <T> RackAwareGraphConstructor<T> create(final AssignmentConfigs assignmentConfigs, final Map<Subtopology, Set<TaskId>> tasksForTopicGroup) {
switch (assignmentConfigs.rackAwareAssignmentStrategy) {
case StreamsConfig.RACK_AWARE_ASSIGNMENT_STRATEGY_MIN_TRAFFIC:
return new MinTrafficGraphConstructor();
return new MinTrafficGraphConstructor<T>();
case StreamsConfig.RACK_AWARE_ASSIGNMENT_STRATEGY_BALANCE_SUBTOPOLOGY:
return new BalanceSubtopologyGraphConstructor(tasksForTopicGroup);
return new BalanceSubtopologyGraphConstructor<T>(tasksForTopicGroup);
default:
throw new IllegalArgumentException("Rack aware assignment is disabled");
}

View File

@ -325,7 +325,7 @@ public class RackAwareTaskAssignor {
}
final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
final List<TaskId> taskIdList = new ArrayList<>(tasks);
final Graph<Integer> graph = new MinTrafficGraphConstructor()
final Graph<Integer> graph = new MinTrafficGraphConstructor<ClientState>()
.constructTaskGraph(
clientList,
taskIdList,
@ -373,7 +373,7 @@ public class RackAwareTaskAssignor {
final List<TaskId> taskIdList = new ArrayList<>(activeTasks);
final Map<TaskId, UUID> taskClientMap = new HashMap<>();
final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
final RackAwareGraphConstructor graphConstructor = RackAwareGraphConstructorFactory.create(assignmentConfigs, tasksForTopicGroup);
final RackAwareGraphConstructor<ClientState> graphConstructor = RackAwareGraphConstructorFactory.create(assignmentConfigs, tasksForTopicGroup);
final Graph<Integer> graph = graphConstructor.constructTaskGraph(
clientList,
taskIdList,
@ -419,7 +419,7 @@ public class RackAwareTaskAssignor {
boolean taskMoved = true;
int round = 0;
final RackAwareGraphConstructor graphConstructor = new MinTrafficGraphConstructor();
final RackAwareGraphConstructor<ClientState> graphConstructor = new MinTrafficGraphConstructor<>();
while (taskMoved && round < STANDBY_OPTIMIZER_MAX_ITERATION) {
taskMoved = false;
round++;

View File

@ -68,7 +68,7 @@ public class RackAwareGraphConstructorTest {
private final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
private final Map<Subtopology, Set<TaskId>> tasksForTopicGroup = getTasksForTopicGroup(TP_SIZE,
PARTITION_SIZE);
private RackAwareGraphConstructor constructor;
private RackAwareGraphConstructor<ClientState> constructor;
@Parameter
public String constructorType;
@ -86,9 +86,9 @@ public class RackAwareGraphConstructorTest {
randomAssignTasksToClient(taskIdList, clientStateMap);
if (constructorType.equals(MIN_COST)) {
constructor = new MinTrafficGraphConstructor();
constructor = new MinTrafficGraphConstructor<>();
} else if (constructorType.equals(BALANCE_SUBTOPOLOGY)) {
constructor = new BalanceSubtopologyGraphConstructor(tasksForTopicGroup);
constructor = new BalanceSubtopologyGraphConstructor<>(tasksForTopicGroup);
}
graph = constructor.constructTaskGraph(
clientList, taskIdList, clientStateMap, taskClientMap, originalAssignedTaskNumber, ClientState::hasAssignedTask, this::getCost, 10, 1, false, false);