KAFKA-15045: (KIP-924 pt. 4) Generify rack graph solving utilities (#15956)

The graph-solving utilities are currently hardcoded to work with ClientState, but they don't actually depend on anything in those state classes.

This change makes the MinTrafficGraphConstructor and BalanceSubtopologyGraphConstructor generic over the client state type, so that they can also be reused with KafkaStreamsState instead of only ClientState.

Reviewers: Anna Sophie Blee-Goldman <ableegoldman@apache.org>, Almog Gavra <almog@responsive.dev>
This commit is contained in:
Antoine Pourchet 2024-05-16 12:37:59 -06:00 committed by GitHub
parent 056d232f4e
commit fafa3c76dc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 40 additions and 40 deletions

View File

@ -35,7 +35,7 @@ import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology; import org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
import org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor.CostFunction; import org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor.CostFunction;
public class BalanceSubtopologyGraphConstructor implements RackAwareGraphConstructor { public class BalanceSubtopologyGraphConstructor<T> implements RackAwareGraphConstructor<T> {
private final Map<Subtopology, Set<TaskId>> tasksForTopicGroup; private final Map<Subtopology, Set<TaskId>> tasksForTopicGroup;
@ -71,10 +71,10 @@ public class BalanceSubtopologyGraphConstructor implements RackAwareGraphConstru
public Graph<Integer> constructTaskGraph( public Graph<Integer> constructTaskGraph(
final List<UUID> clientList, final List<UUID> clientList,
final List<TaskId> taskIdList, final List<TaskId> taskIdList,
final Map<UUID, ClientState> clientStates, final Map<UUID, T> clientStates,
final Map<TaskId, UUID> taskClientMap, final Map<TaskId, UUID> taskClientMap,
final Map<UUID, Integer> originalAssignedTaskNumber, final Map<UUID, Integer> originalAssignedTaskNumber,
final BiPredicate<ClientState, TaskId> hasAssignedTask, final BiPredicate<T, TaskId> hasAssignedTask,
final CostFunction costFunction, final CostFunction costFunction,
final int trafficCost, final int trafficCost,
final int nonOverlapCost, final int nonOverlapCost,
@ -86,7 +86,7 @@ public class BalanceSubtopologyGraphConstructor implements RackAwareGraphConstru
final Graph<Integer> graph = new Graph<>(); final Graph<Integer> graph = new Graph<>();
for (final TaskId taskId : taskIdList) { for (final TaskId taskId : taskIdList) {
for (final Entry<UUID, ClientState> clientState : clientStates.entrySet()) { for (final Entry<UUID, T> clientState : clientStates.entrySet()) {
if (hasAssignedTask.test(clientState.getValue(), taskId)) { if (hasAssignedTask.test(clientState.getValue(), taskId)) {
originalAssignedTaskNumber.merge(clientState.getKey(), 1, Integer::sum); originalAssignedTaskNumber.merge(clientState.getKey(), 1, Integer::sum);
} }
@ -122,12 +122,12 @@ public class BalanceSubtopologyGraphConstructor implements RackAwareGraphConstru
final Graph<Integer> graph, final Graph<Integer> graph,
final List<UUID> clientList, final List<UUID> clientList,
final List<TaskId> taskIdList, final List<TaskId> taskIdList,
final Map<UUID, ClientState> clientStates, final Map<UUID, T> clientStates,
final Map<UUID, Integer> originalAssignedTaskNumber, final Map<UUID, Integer> originalAssignedTaskNumber,
final Map<TaskId, UUID> taskClientMap, final Map<TaskId, UUID> taskClientMap,
final BiConsumer<ClientState, TaskId> assignTask, final BiConsumer<T, TaskId> assignTask,
final BiConsumer<ClientState, TaskId> unAssignTask, final BiConsumer<T, TaskId> unAssignTask,
final BiPredicate<ClientState, TaskId> hasAssignedTask final BiPredicate<T, TaskId> hasAssignedTask
) { ) {
final SortedMap<Subtopology, Set<TaskId>> sortedTasksForTopicGroup = new TreeMap<>(tasksForTopicGroup); final SortedMap<Subtopology, Set<TaskId>> sortedTasksForTopicGroup = new TreeMap<>(tasksForTopicGroup);
final Set<TaskId> taskIdSet = new HashSet<>(taskIdList); final Set<TaskId> taskIdSet = new HashSet<>(taskIdList);
@ -170,10 +170,10 @@ public class BalanceSubtopologyGraphConstructor implements RackAwareGraphConstru
final Graph<Integer> graph, final Graph<Integer> graph,
final List<TaskId> taskIdList, final List<TaskId> taskIdList,
final List<UUID> clientList, final List<UUID> clientList,
final Map<UUID, ClientState> clientStates, final Map<UUID, T> clientStates,
final Map<TaskId, UUID> taskClientMap, final Map<TaskId, UUID> taskClientMap,
final Map<UUID, Integer> originalAssignedTaskNumber, final Map<UUID, Integer> originalAssignedTaskNumber,
final BiPredicate<ClientState, TaskId> hasAssignedTask, final BiPredicate<T, TaskId> hasAssignedTask,
final CostFunction costFunction, final CostFunction costFunction,
final int trafficCost, final int trafficCost,
final int nonOverlapCost, final int nonOverlapCost,

View File

@ -28,7 +28,7 @@ import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology; import org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
import org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor.CostFunction; import org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor.CostFunction;
public class MinTrafficGraphConstructor implements RackAwareGraphConstructor { public class MinTrafficGraphConstructor<T> implements RackAwareGraphConstructor<T> {
@Override @Override
public int getSinkNodeID( public int getSinkNodeID(
@ -53,10 +53,10 @@ public class MinTrafficGraphConstructor implements RackAwareGraphConstructor {
public Graph<Integer> constructTaskGraph( public Graph<Integer> constructTaskGraph(
final List<UUID> clientList, final List<UUID> clientList,
final List<TaskId> taskIdList, final List<TaskId> taskIdList,
final Map<UUID, ClientState> clientStates, final Map<UUID, T> clientStates,
final Map<TaskId, UUID> taskClientMap, final Map<TaskId, UUID> taskClientMap,
final Map<UUID, Integer> originalAssignedTaskNumber, final Map<UUID, Integer> originalAssignedTaskNumber,
final BiPredicate<ClientState, TaskId> hasAssignedTask, final BiPredicate<T, TaskId> hasAssignedTask,
final CostFunction costFunction, final CostFunction costFunction,
final int trafficCost, final int trafficCost,
final int nonOverlapCost, final int nonOverlapCost,
@ -66,7 +66,7 @@ public class MinTrafficGraphConstructor implements RackAwareGraphConstructor {
final Graph<Integer> graph = new Graph<>(); final Graph<Integer> graph = new Graph<>();
for (final TaskId taskId : taskIdList) { for (final TaskId taskId : taskIdList) {
for (final Entry<UUID, ClientState> clientState : clientStates.entrySet()) { for (final Entry<UUID, T> clientState : clientStates.entrySet()) {
if (hasAssignedTask.test(clientState.getValue(), taskId)) { if (hasAssignedTask.test(clientState.getValue(), taskId)) {
originalAssignedTaskNumber.merge(clientState.getKey(), 1, Integer::sum); originalAssignedTaskNumber.merge(clientState.getKey(), 1, Integer::sum);
} }
@ -122,12 +122,12 @@ public class MinTrafficGraphConstructor implements RackAwareGraphConstructor {
final Graph<Integer> graph, final Graph<Integer> graph,
final List<UUID> clientList, final List<UUID> clientList,
final List<TaskId> taskIdList, final List<TaskId> taskIdList,
final Map<UUID, ClientState> clientStates, final Map<UUID, T> clientStates,
final Map<UUID, Integer> originalAssignedTaskNumber, final Map<UUID, Integer> originalAssignedTaskNumber,
final Map<TaskId, UUID> taskClientMap, final Map<TaskId, UUID> taskClientMap,
final BiConsumer<ClientState, TaskId> assignTask, final BiConsumer<T, TaskId> assignTask,
final BiConsumer<ClientState, TaskId> unAssignTask, final BiConsumer<T, TaskId> unAssignTask,
final BiPredicate<ClientState, TaskId> hasAssignedTask final BiPredicate<T, TaskId> hasAssignedTask
) { ) {
int tasksAssigned = 0; int tasksAssigned = 0;
boolean taskMoved = false; boolean taskMoved = false;

View File

@ -33,7 +33,7 @@ import org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssi
/** /**
* Construct graph for rack aware task assignor * Construct graph for rack aware task assignor
*/ */
public interface RackAwareGraphConstructor { public interface RackAwareGraphConstructor<T> {
int SOURCE_ID = -1; int SOURCE_ID = -1;
int getSinkNodeID(final List<TaskId> taskIdList, final List<UUID> clientList, final Map<Subtopology, Set<TaskId>> tasksForTopicGroup); int getSinkNodeID(final List<TaskId> taskIdList, final List<UUID> clientList, final Map<Subtopology, Set<TaskId>> tasksForTopicGroup);
@ -45,10 +45,10 @@ public interface RackAwareGraphConstructor {
Graph<Integer> constructTaskGraph( Graph<Integer> constructTaskGraph(
final List<UUID> clientList, final List<UUID> clientList,
final List<TaskId> taskIdList, final List<TaskId> taskIdList,
final Map<UUID, ClientState> clientStates, final Map<UUID, T> clientStates,
final Map<TaskId, UUID> taskClientMap, final Map<TaskId, UUID> taskClientMap,
final Map<UUID, Integer> originalAssignedTaskNumber, final Map<UUID, Integer> originalAssignedTaskNumber,
final BiPredicate<ClientState, TaskId> hasAssignedTask, final BiPredicate<T, TaskId> hasAssignedTask,
final CostFunction costFunction, final CostFunction costFunction,
final int trafficCost, final int trafficCost,
final int nonOverlapCost, final int nonOverlapCost,
@ -59,24 +59,24 @@ public interface RackAwareGraphConstructor {
final Graph<Integer> graph, final Graph<Integer> graph,
final List<UUID> clientList, final List<UUID> clientList,
final List<TaskId> taskIdList, final List<TaskId> taskIdList,
final Map<UUID, ClientState> clientStates, final Map<UUID, T> clientStates,
final Map<UUID, Integer> originalAssignedTaskNumber, final Map<UUID, Integer> originalAssignedTaskNumber,
final Map<TaskId, UUID> taskClientMap, final Map<TaskId, UUID> taskClientMap,
final BiConsumer<ClientState, TaskId> assignTask, final BiConsumer<T, TaskId> assignTask,
final BiConsumer<ClientState, TaskId> unAssignTask, final BiConsumer<T, TaskId> unAssignTask,
final BiPredicate<ClientState, TaskId> hasAssignedTask); final BiPredicate<T, TaskId> hasAssignedTask);
default KeyValue<Boolean, Integer> assignTaskToClient( default KeyValue<Boolean, Integer> assignTaskToClient(
final Graph<Integer> graph, final Graph<Integer> graph,
final TaskId taskId, final TaskId taskId,
final int taskNodeId, final int taskNodeId,
final int topicGroupIndex, final int topicGroupIndex,
final Map<UUID, ClientState> clientStates, final Map<UUID, T> clientStates,
final List<UUID> clientList, final List<UUID> clientList,
final List<TaskId> taskIdList, final List<TaskId> taskIdList,
final Map<TaskId, UUID> taskClientMap, final Map<TaskId, UUID> taskClientMap,
final BiConsumer<ClientState, TaskId> assignTask, final BiConsumer<T, TaskId> assignTask,
final BiConsumer<ClientState, TaskId> unAssignTask final BiConsumer<T, TaskId> unAssignTask
) { ) {
int tasksAssigned = 0; int tasksAssigned = 0;
boolean taskMoved = false; boolean taskMoved = false;
@ -104,9 +104,9 @@ public interface RackAwareGraphConstructor {
default void validateAssignedTask( default void validateAssignedTask(
final List<TaskId> taskIdList, final List<TaskId> taskIdList,
final int tasksAssigned, final int tasksAssigned,
final Map<UUID, ClientState> clientStates, final Map<UUID, T> clientStates,
final Map<UUID, Integer> originalAssignedTaskNumber, final Map<UUID, Integer> originalAssignedTaskNumber,
final BiPredicate<ClientState, TaskId> hasAssignedTask final BiPredicate<T, TaskId> hasAssignedTask
) { ) {
// Validate task assigned // Validate task assigned
if (tasksAssigned != taskIdList.size()) { if (tasksAssigned != taskIdList.size()) {
@ -117,7 +117,7 @@ public interface RackAwareGraphConstructor {
// Validate original assigned task number matches // Validate original assigned task number matches
final Map<UUID, Integer> assignedTaskNumber = new HashMap<>(); final Map<UUID, Integer> assignedTaskNumber = new HashMap<>();
for (final TaskId taskId : taskIdList) { for (final TaskId taskId : taskIdList) {
for (final Entry<UUID, ClientState> clientState : clientStates.entrySet()) { for (final Entry<UUID, T> clientState : clientStates.entrySet()) {
if (hasAssignedTask.test(clientState.getValue(), taskId)) { if (hasAssignedTask.test(clientState.getValue(), taskId)) {
assignedTaskNumber.merge(clientState.getKey(), 1, Integer::sum); assignedTaskNumber.merge(clientState.getKey(), 1, Integer::sum);
} }

View File

@ -25,12 +25,12 @@ import org.apache.kafka.streams.processor.internals.assignment.AssignorConfigura
public class RackAwareGraphConstructorFactory { public class RackAwareGraphConstructorFactory {
static RackAwareGraphConstructor create(final AssignmentConfigs assignmentConfigs, final Map<Subtopology, Set<TaskId>> tasksForTopicGroup) { static <T> RackAwareGraphConstructor<T> create(final AssignmentConfigs assignmentConfigs, final Map<Subtopology, Set<TaskId>> tasksForTopicGroup) {
switch (assignmentConfigs.rackAwareAssignmentStrategy) { switch (assignmentConfigs.rackAwareAssignmentStrategy) {
case StreamsConfig.RACK_AWARE_ASSIGNMENT_STRATEGY_MIN_TRAFFIC: case StreamsConfig.RACK_AWARE_ASSIGNMENT_STRATEGY_MIN_TRAFFIC:
return new MinTrafficGraphConstructor(); return new MinTrafficGraphConstructor<T>();
case StreamsConfig.RACK_AWARE_ASSIGNMENT_STRATEGY_BALANCE_SUBTOPOLOGY: case StreamsConfig.RACK_AWARE_ASSIGNMENT_STRATEGY_BALANCE_SUBTOPOLOGY:
return new BalanceSubtopologyGraphConstructor(tasksForTopicGroup); return new BalanceSubtopologyGraphConstructor<T>(tasksForTopicGroup);
default: default:
throw new IllegalArgumentException("Rack aware assignment is disabled"); throw new IllegalArgumentException("Rack aware assignment is disabled");
} }

View File

@ -325,7 +325,7 @@ public class RackAwareTaskAssignor {
} }
final List<UUID> clientList = new ArrayList<>(clientStates.keySet()); final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
final List<TaskId> taskIdList = new ArrayList<>(tasks); final List<TaskId> taskIdList = new ArrayList<>(tasks);
final Graph<Integer> graph = new MinTrafficGraphConstructor() final Graph<Integer> graph = new MinTrafficGraphConstructor<ClientState>()
.constructTaskGraph( .constructTaskGraph(
clientList, clientList,
taskIdList, taskIdList,
@ -373,7 +373,7 @@ public class RackAwareTaskAssignor {
final List<TaskId> taskIdList = new ArrayList<>(activeTasks); final List<TaskId> taskIdList = new ArrayList<>(activeTasks);
final Map<TaskId, UUID> taskClientMap = new HashMap<>(); final Map<TaskId, UUID> taskClientMap = new HashMap<>();
final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>(); final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
final RackAwareGraphConstructor graphConstructor = RackAwareGraphConstructorFactory.create(assignmentConfigs, tasksForTopicGroup); final RackAwareGraphConstructor<ClientState> graphConstructor = RackAwareGraphConstructorFactory.create(assignmentConfigs, tasksForTopicGroup);
final Graph<Integer> graph = graphConstructor.constructTaskGraph( final Graph<Integer> graph = graphConstructor.constructTaskGraph(
clientList, clientList,
taskIdList, taskIdList,
@ -419,7 +419,7 @@ public class RackAwareTaskAssignor {
boolean taskMoved = true; boolean taskMoved = true;
int round = 0; int round = 0;
final RackAwareGraphConstructor graphConstructor = new MinTrafficGraphConstructor(); final RackAwareGraphConstructor<ClientState> graphConstructor = new MinTrafficGraphConstructor<>();
while (taskMoved && round < STANDBY_OPTIMIZER_MAX_ITERATION) { while (taskMoved && round < STANDBY_OPTIMIZER_MAX_ITERATION) {
taskMoved = false; taskMoved = false;
round++; round++;

View File

@ -68,7 +68,7 @@ public class RackAwareGraphConstructorTest {
private final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>(); private final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
private final Map<Subtopology, Set<TaskId>> tasksForTopicGroup = getTasksForTopicGroup(TP_SIZE, private final Map<Subtopology, Set<TaskId>> tasksForTopicGroup = getTasksForTopicGroup(TP_SIZE,
PARTITION_SIZE); PARTITION_SIZE);
private RackAwareGraphConstructor constructor; private RackAwareGraphConstructor<ClientState> constructor;
@Parameter @Parameter
public String constructorType; public String constructorType;
@ -86,9 +86,9 @@ public class RackAwareGraphConstructorTest {
randomAssignTasksToClient(taskIdList, clientStateMap); randomAssignTasksToClient(taskIdList, clientStateMap);
if (constructorType.equals(MIN_COST)) { if (constructorType.equals(MIN_COST)) {
constructor = new MinTrafficGraphConstructor(); constructor = new MinTrafficGraphConstructor<>();
} else if (constructorType.equals(BALANCE_SUBTOPOLOGY)) { } else if (constructorType.equals(BALANCE_SUBTOPOLOGY)) {
constructor = new BalanceSubtopologyGraphConstructor(tasksForTopicGroup); constructor = new BalanceSubtopologyGraphConstructor<>(tasksForTopicGroup);
} }
graph = constructor.constructTaskGraph( graph = constructor.constructTaskGraph(
clientList, taskIdList, clientStateMap, taskClientMap, originalAssignedTaskNumber, ClientState::hasAssignedTask, this::getCost, 10, 1, false, false); clientList, taskIdList, clientStateMap, taskClientMap, originalAssignedTaskNumber, ClientState::hasAssignedTask, this::getCost, 10, 1, false, false);