KAFKA-15022: [2/N] introduce graph to compute min cost (#13996)

Part of KIP-925.

Reviewers: Matthias J. Sax <matthias@confluent.io>
This commit is contained in:
Hao Li 2023-07-20 16:17:47 -07:00 committed by GitHub
parent 125dbb9286
commit 6bb88ae2f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 793 additions and 2 deletions

View File

@ -0,0 +1,377 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.streams.processor.internals.assignment;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set;
import java.util.SortedMap;
import java.util.SortedSet;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.stream.Collectors;
public class Graph<V extends Comparable<V>> {
public class Edge {
final V destination;
final int capacity;
final int cost;
int residualFlow;
int flow;
Edge counterEdge;
boolean forwardEdge;
public Edge(final V destination, final int capacity, final int cost, final int residualFlow, final int flow) {
this(destination, capacity, cost, residualFlow, flow, true);
}
public Edge(final V destination, final int capacity, final int cost, final int residualFlow, final int flow,
final boolean forwardEdge) {
Objects.requireNonNull(destination);
if (capacity < 0) {
throw new IllegalArgumentException("Edge capacity cannot be negative");
}
if (flow > capacity) {
throw new IllegalArgumentException(String.format("Edge flow %d cannot exceed capacity %d",
flow, capacity));
}
this.destination = destination;
this.capacity = capacity;
this.cost = cost;
this.residualFlow = residualFlow;
this.flow = flow;
this.forwardEdge = forwardEdge;
}
@Override
public boolean equals(final Object other) {
if (this == other) {
return true;
}
if (other == null || other.getClass() != getClass()) {
return false;
}
final Graph<?>.Edge otherEdge = (Graph<?>.Edge) other;
return destination.equals(otherEdge.destination) && capacity == otherEdge.capacity
&& cost == otherEdge.cost && residualFlow == otherEdge.residualFlow && flow == otherEdge.flow
&& forwardEdge == otherEdge.forwardEdge;
}
@Override
public int hashCode() {
return Objects.hash(destination, capacity, cost, residualFlow, flow, forwardEdge);
}
@Override
public String toString() {
return "{destination= " + destination + ", capacity=" + capacity + ", cost=" + cost
+ ", residualFlow=" + residualFlow + ", flow=" + flow + ", forwardEdge=" + forwardEdge;
}
}
private final SortedMap<V, SortedMap<V, Edge>> adjList = new TreeMap<>();
private final SortedSet<V> nodes = new TreeSet<>();
private final boolean isResidualGraph;
private V sourceNode, sinkNode;
public Graph() {
this(false);
}
private Graph(final boolean isResidualGraph) {
this.isResidualGraph = isResidualGraph;
}
public void addEdge(final V u, final V v, final int capacity, final int cost, final int flow) {
addEdge(u, new Edge(v, capacity, cost, capacity - flow, flow));
}
public Set<V> nodes() {
return nodes;
}
public Map<V, Edge> edges(final V node) {
return adjList.get(node);
}
public boolean isResidualGraph() {
return isResidualGraph;
}
public void setSourceNode(final V node) {
sourceNode = node;
}
public void setSinkNode(final V node) {
sinkNode = node;
}
public int totalCost() {
int totalCost = 0;
for (final Map.Entry<V, SortedMap<V, Edge>> nodeEdges : adjList.entrySet()) {
final SortedMap<V, Edge> edges = nodeEdges.getValue();
for (final Entry<V, Edge> nodeEdge : edges.entrySet()) {
totalCost += nodeEdge.getValue().cost * nodeEdge.getValue().flow;
}
}
return totalCost;
}
private void addEdge(final V u, final Edge edge) {
if (!isResidualGraph) {
// Check if there's already an edge from u to v
final Map<V, Edge> edgeMap = adjList.get(edge.destination);
if (edgeMap != null && edgeMap.containsKey(u)) {
throw new IllegalArgumentException(
"There is already an edge from " + edge.destination
+ " to " + u + ". Can not add an edge from " + u + " to " + edge.destination
+ " since there will create a cycle between two nodes");
}
}
adjList.computeIfAbsent(u, set -> new TreeMap<>()).put(edge.destination, edge);
nodes.add(u);
nodes.add(edge.destination);
}
/**
* Get residual graph of this graph.
* Residual graph definition:
* If there is an edge in original graph from u to v with capacity c, cost w and flow f,
* then in the new graph there are two edges e1 and e2. e1 is from u to v with capacity c - f,
* cost w and flow f. e2 is from v to u with capacity f, cost -w and flow 0.
*
* @return Residual graph
*/
public Graph<V> residualGraph() {
if (isResidualGraph) {
return this;
}
final Graph<V> residualGraph = new Graph<>(true);
for (final Map.Entry<V, SortedMap<V, Edge>> nodeEdges : adjList.entrySet()) {
final V node = nodeEdges.getKey();
final SortedMap<V, Edge> edges = nodeEdges.getValue();
for (final Entry<V, Edge> nodeEdge : edges.entrySet()) {
final Edge edge = nodeEdge.getValue();
final Edge forwardEdge = new Edge(edge.destination, edge.capacity, edge.cost, edge.capacity - edge.flow, edge.flow);
final Edge backwardEdge = new Edge(node, edge.capacity, edge.cost * -1, edge.flow, 0, false);
forwardEdge.counterEdge = backwardEdge;
backwardEdge.counterEdge = forwardEdge;
residualGraph.addEdge(node, forwardEdge);
residualGraph.addEdge(edge.destination, backwardEdge);
}
}
return residualGraph;
}
/**
* Solve min cost flow with cycle canceling algorithm.
*/
public void solveMinCostFlow() {
validateMinCostGraph();
final Graph<V> residualGraph = residualGraph();
residualGraph.cancelNegativeCycles();
for (final Entry<V, SortedMap<V, Edge>> nodeEdges : adjList.entrySet()) {
final V node = nodeEdges.getKey();
for (final Entry<V, Edge> nodeEdge : nodeEdges.getValue().entrySet()) {
final V destination = nodeEdge.getKey();
final Edge edge = nodeEdge.getValue();
final Edge residualEdge = residualGraph.adjList.get(node).get(destination);
edge.flow = residualEdge.flow;
edge.residualFlow = residualEdge.residualFlow;
}
}
}
private void populateInOutFlow(final Map<V, Long> inFlow, final Map<V, Long> outFlow) {
for (final Entry<V, SortedMap<V, Edge>> nodeEdges : adjList.entrySet()) {
final V node = nodeEdges.getKey();
if (node.equals(sinkNode)) {
throw new IllegalStateException("Sink node " + sinkNode + " shouldn't have output");
}
for (final Entry<V, Edge> nodeEdge : nodeEdges.getValue().entrySet()) {
final V destination = nodeEdge.getKey();
if (destination.equals(sourceNode)) {
throw new IllegalStateException("Source node " + sourceNode + " shouldn't have input " + node);
}
final Edge edge = nodeEdge.getValue();
Long count = outFlow.get(node);
if (count == null) {
outFlow.put(node, (long) edge.flow);
} else {
outFlow.put(node, count + edge.flow);
}
count = inFlow.get(destination);
if (count == null) {
inFlow.put(destination, (long) edge.flow);
} else {
inFlow.put(destination, count + edge.flow);
}
}
}
}
private void validateMinCostGraph() {
if (isResidualGraph) {
throw new IllegalStateException("Should not be residual graph to solve min cost flow");
}
/*
Check provided flow satisfying below constraints:
1. Input flow and output flow for each node should be the same except for source and destination node
2. Output flow of source and input flow of destination should be the same
*/
final Map<V, Long> inFlow = new HashMap<>();
final Map<V, Long> outFlow = new HashMap<>();
populateInOutFlow(inFlow, outFlow);
for (final Entry<V, Long> in : inFlow.entrySet()) {
if (in.getKey().equals(sourceNode) || in.getKey().equals(sinkNode)) {
continue;
}
final Long out = outFlow.get(in.getKey());
if (!Objects.equals(in.getValue(), out)) {
throw new IllegalStateException("Input flow for node " + in.getKey() + " is " +
in.getValue() + " which doesn't match output flow " + out);
}
}
final Long sourceOutput = outFlow.get(sourceNode);
final Long sinkInput = inFlow.get(sinkNode);
if (!Objects.equals(sourceOutput, sinkInput)) {
throw new IllegalStateException("Output flow for source " + sourceNode + " is " + sourceOutput
+ " which doesn't match input flow " + sinkInput + " for sink " + sinkNode);
}
}
private void cancelNegativeCycles() {
if (!isResidualGraph) {
throw new IllegalStateException("Should be residual graph to cancel negative cycles");
}
boolean cyclePossible = true;
while (cyclePossible) {
cyclePossible = false;
for (final V node : nodes) {
final Map<V, V> parentNodes = new HashMap<>();
final Map<V, Edge> parentEdges = new HashMap<>();
final V possibleNodeInCycle = detectNegativeCycles(node, parentNodes, parentEdges);
if (possibleNodeInCycle == null) {
continue;
}
final Set<V> visited = new HashSet<>();
V nodeInCycle = possibleNodeInCycle;
while (!visited.contains(nodeInCycle)) {
visited.add(nodeInCycle);
nodeInCycle = parentNodes.get(nodeInCycle);
}
cyclePossible = true;
cancelNegativeCycle(nodeInCycle, parentNodes, parentEdges);
}
}
}
private void cancelNegativeCycle(final V nodeInCycle, final Map<V, V> parentNodes, final Map<V, Edge> parentEdges) {
// Start from parentNode since nodeInCyle is used as exit condition in below loops
final V parentNode = parentNodes.get(nodeInCycle);
Edge parentEdge = parentEdges.get(nodeInCycle);
// Find max possible negative flow
int possibleFlow = parentEdge.residualFlow;
for (V curNode = parentNode; curNode != nodeInCycle; curNode = parentNodes.get(curNode)) {
parentEdge = parentEdges.get(curNode);
possibleFlow = Math.min(possibleFlow, parentEdge.residualFlow);
}
// Update graph by removing negative flow
parentEdge = parentEdges.get(nodeInCycle);
Edge counterEdge = parentEdge.counterEdge;
parentEdge.residualFlow -= possibleFlow;
if (parentEdge.forwardEdge) {
parentEdge.flow += possibleFlow;
}
counterEdge.residualFlow += possibleFlow;
if (counterEdge.forwardEdge && counterEdge.flow >= possibleFlow) {
counterEdge.flow -= possibleFlow;
}
for (V curNode = parentNode; curNode != nodeInCycle; curNode = parentNodes.get(curNode)) {
parentEdge = parentEdges.get(curNode);
counterEdge = parentEdge.counterEdge;
parentEdge.residualFlow -= possibleFlow;
if (parentEdge.forwardEdge) {
parentEdge.flow += possibleFlow;
}
counterEdge.residualFlow += possibleFlow;
if (counterEdge.forwardEdge && counterEdge.flow >= possibleFlow) {
counterEdge.flow -= possibleFlow;
}
}
}
/**
* Detect negative cycle using Bellman-ford shortest path algorithm.
* @param source Source node
* @param parentNodes Parent nodes to store negative cycle nodes
* @param parentEdges Parent edges to store negative cycle edges
*
* @return One node which can lead to negative cycle if exists or null if there's no negative cycle
*/
V detectNegativeCycles(final V source, final Map<V, V> parentNodes, final Map<V, Edge> parentEdges) {
// Use long to account for any overflow
final Map<V, Long> distance = nodes.stream().collect(Collectors.toMap(node -> node, node -> (long) Integer.MAX_VALUE));
distance.put(source, 0L);
final int nodeCount = nodes.size();
// Iterate nodeCount iterations since Bellaman-Ford will find shortest path in nodeCount - 1
// iterations. If the distance can still be relaxed in nodeCount iteration, there's a negative
// cycle
for (int i = 0; i < nodeCount; i++) {
// Iterate through all edges
for (final Entry<V, SortedMap<V, Edge>> nodeEdges : adjList.entrySet()) {
final V u = nodeEdges.getKey();
for (final Entry<V, Edge> nodeEdge : nodeEdges.getValue().entrySet()) {
final Edge edge = nodeEdge.getValue();
if (edge.residualFlow == 0) {
continue;
}
final V v = edge.destination;
if (distance.get(v) > distance.get(u) + edge.cost) {
if (i == nodeCount - 1) {
return v;
}
distance.put(v, distance.get(u) + edge.cost);
parentNodes.put(v, u);
parentEdges.put(v, edge);
}
}
}
}
return null;
}
}

View File

@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
package org.apache.kafka.streams.processor.internals; package org.apache.kafka.streams.processor.internals.assignment;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
@ -30,6 +30,7 @@ import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.TopicPartitionInfo; import org.apache.kafka.common.TopicPartitionInfo;
import org.apache.kafka.streams.KeyValue; import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.processor.TaskId; import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.internals.InternalTopicManager;
import org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology; import org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs; import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs;
import org.slf4j.Logger; import org.slf4j.Logger;

View File

@ -0,0 +1,414 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.streams.processor.internals.assignment;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.contains;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.junit.Before;
import org.junit.Test;
public class GraphTest {
private Graph<Integer> graph;
@Before
public void setUp() {
/*
* Node 0 and 2 are both connected to node 1 and 3. There's a flow of 1 unit from 0 to 1 and 2 to
* 3. The total cost in this case is 5. Min cost should be 2 by flowing 1 unit from 0 to 3 and 2
* to 1
*/
graph = new Graph<>();
graph.addEdge(0, 1, 1, 3, 1);
graph.addEdge(0, 3, 1, 1, 0);
graph.addEdge(2, 1, 1, 1, 0);
graph.addEdge(2, 3, 1, 2, 1);
graph.addEdge(4, 0, 1, 0, 1);
graph.addEdge(4, 2, 1, 0, 1);
graph.addEdge(1, 5, 1, 0, 1);
graph.addEdge(3, 5, 1, 0, 1);
graph.setSourceNode(4);
graph.setSinkNode(5);
}
@Test
public void testBasic() {
final Set<Integer> nodes = graph.nodes();
assertEquals(6, nodes.size());
assertThat(nodes, contains(0, 1, 2, 3, 4, 5));
Map<Integer, Graph<Integer>.Edge> edges = graph.edges(0);
assertEquals(2, edges.size());
assertEquals(getEdge(1, 1, 3, 0, 1), edges.get(1));
assertEquals(getEdge(3, 1, 1, 1, 0), edges.get(3));
edges = graph.edges(2);
assertEquals(2, edges.size());
assertEquals(getEdge(1, 1, 1, 1, 0), edges.get(1));
assertEquals(getEdge(3, 1, 2, 0, 1), edges.get(3));
edges = graph.edges(1);
assertEquals(1, edges.size());
assertEquals(getEdge(5, 1, 0, 0, 1), edges.get(5));
edges = graph.edges(3);
assertEquals(1, edges.size());
assertEquals(getEdge(5, 1, 0, 0, 1), edges.get(5));
edges = graph.edges(4);
assertEquals(2, edges.size());
assertEquals(getEdge(0, 1, 0, 0, 1), edges.get(0));
assertEquals(getEdge(2, 1, 0, 0, 1), edges.get(2));
edges = graph.edges(5);
assertNull(edges);
assertFalse(graph.isResidualGraph());
}
@Test
public void testResidualGraph() {
final Graph<Integer> residualGraph = graph.residualGraph();
final Graph<Integer> residualGraph1 = residualGraph.residualGraph();
assertSame(residualGraph1, residualGraph);
final Set<Integer> nodes = residualGraph.nodes();
assertEquals(6, nodes.size());
assertThat(nodes, contains(0, 1, 2, 3, 4, 5));
Map<Integer, Graph<Integer>.Edge> edges = residualGraph.edges(0);
assertEquals(3, edges.size());
assertEquals(getEdge(1, 1, 3, 0, 1), edges.get(1));
assertEquals(getEdge(3, 1, 1, 1, 0), edges.get(3));
assertEquals(getEdge(4, 1, 0, 1, 0, false), edges.get(4));
edges = residualGraph.edges(2);
assertEquals(3, edges.size());
assertEquals(getEdge(1, 1, 1, 1, 0), edges.get(1));
assertEquals(getEdge(3, 1, 2, 0, 1), edges.get(3));
assertEquals(getEdge(4, 1, 0, 1, 0, false), edges.get(4));
edges = residualGraph.edges(1);
assertEquals(3, edges.size());
assertEquals(getEdge(0, 1, -3, 1, 0, false), edges.get(0));
assertEquals(getEdge(2, 1, -1, 0, 0, false), edges.get(2));
assertEquals(getEdge(5, 1, 0, 0, 1), edges.get(5));
edges = residualGraph.edges(3);
assertEquals(3, edges.size());
assertEquals(getEdge(0, 1, -1, 0, 0, false), edges.get(0));
assertEquals(getEdge(2, 1, -2, 1, 0, false), edges.get(2));
assertEquals(getEdge(5, 1, 0, 0, 1), edges.get(5));
assertTrue(residualGraph.isResidualGraph());
}
@Test
public void testInvalidOperation() {
final Graph<Integer> graph1 = new Graph<>();
Exception exception = assertThrows(IllegalArgumentException.class, () -> graph1.addEdge(0, 1, -1, 0, 0));
assertEquals("Edge capacity cannot be negative", exception.getMessage());
exception = assertThrows(IllegalArgumentException.class, () -> graph1.addEdge(0, 1, 1, 0, 2));
assertEquals("Edge flow 2 cannot exceed capacity 1", exception.getMessage());
graph1.addEdge(0, 1, 1, 1, 1);
exception = assertThrows(IllegalArgumentException.class, () -> graph1.addEdge(1, 0, 1, 0, 0));
assertEquals("There is already an edge from 0 to 1. Can not add an edge from 1 to 0 since "
+ "there will create a cycle between two nodes", exception.getMessage());
final Graph<Integer> residualGraph = graph1.residualGraph();
exception = assertThrows(IllegalStateException.class, residualGraph::solveMinCostFlow);
assertEquals("Should not be residual graph to solve min cost flow", exception.getMessage());
}
@Test
public void testInvalidSource() {
final Graph<Integer> graph1 = new Graph<>();
graph1.addEdge(0, 1, 1, 1, 0);
graph1.addEdge(1, 2, 1, 1, 0);
graph1.setSourceNode(1);
graph1.setSinkNode(2);
final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow);
assertEquals("Source node 1 shouldn't have input 0", exception.getMessage());
}
@Test
public void testInvalidSink() {
final Graph<Integer> graph1 = new Graph<>();
graph1.addEdge(0, 1, 1, 1, 0);
graph1.addEdge(1, 2, 1, 1, 0);
graph1.setSourceNode(0);
graph1.setSinkNode(1);
final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow);
assertEquals("Sink node 1 shouldn't have output", exception.getMessage());
}
@Test
public void testInvalidFlow() {
final Graph<Integer> graph1 = new Graph<>();
graph1.addEdge(0, 1, 1, 1, 1);
graph1.addEdge(0, 2, 2, 1, 2);
graph1.addEdge(1, 3, 1, 1, 1);
graph1.addEdge(2, 3, 2, 1, 0); // Missing flow from 2 to 3
graph1.setSourceNode(0);
graph1.setSinkNode(3);
final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow);
assertEquals("Input flow for node 2 is 2 which doesn't match output flow 0", exception.getMessage());
}
@Test
public void testMissingSource() {
final Graph<Integer> graph1 = new Graph<>();
graph1.addEdge(0, 1, 1, 1, 1);
graph1.addEdge(0, 2, 2, 1, 2);
graph1.addEdge(1, 3, 1, 1, 1);
graph1.addEdge(2, 3, 2, 1, 2);
graph1.setSinkNode(3);
final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow);
assertEquals("Output flow for source null is null which doesn't match input flow 3 for sink 3",
exception.getMessage());
}
@Test
public void testDisconnectedGraph() {
final Graph<Integer> graph1 = new Graph<>();
graph1.addEdge(0, 1, 1, 1, 1);
graph1.addEdge(2, 3, 2, 1, 2);
graph1.setSourceNode(0);
graph1.setSinkNode(1);
final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow);
assertEquals("Input flow for node 3 is 2 which doesn't match output flow null",
exception.getMessage());
}
@Test
public void testDisconnectedGraphCrossSourceSink() {
final Graph<Integer> graph1 = new Graph<>();
graph1.addEdge(0, 1, 1, 1, 1);
graph1.addEdge(2, 3, 2, 1, 2);
graph1.setSourceNode(0);
graph1.setSinkNode(3);
final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow);
assertEquals("Input flow for node 1 is 1 which doesn't match output flow null",
exception.getMessage());
}
@Test
public void testJustSourceSink() {
final Graph<Integer> graph1 = new Graph<>();
graph1.addEdge(0, 1, 1, 1, 1);
graph1.setSourceNode(0);
graph1.setSinkNode(1);
graph1.solveMinCostFlow();
assertEquals(1, graph1.totalCost());
}
@Test
public void testMinCostFlow() {
// Original graph, flow from 0 to 1 and 2 to 3
Map<Integer, Graph<Integer>.Edge> edges = graph.edges(0);
Graph<Integer>.Edge edge = edges.get(1);
assertEquals(1, edge.flow);
assertEquals(0, edge.residualFlow);
edge = edges.get(3);
assertEquals(0, edge.flow);
assertEquals(1, edge.residualFlow);
edges = graph.edges(2);
edge = edges.get(3);
assertEquals(1, edge.flow);
assertEquals(0, edge.residualFlow);
edge = edges.get(1);
assertEquals(0, edge.flow);
assertEquals(1, edge.residualFlow);
assertEquals(5, graph.totalCost());
graph.solveMinCostFlow();
assertEquals(2, graph.totalCost());
edges = graph.edges(0);
assertEquals(2, edges.size());
// No flow from 0 to 1
edge = edges.get(1);
assertEquals(0, edge.flow);
assertEquals(1, edge.residualFlow);
// Flow from 0 to 3 now
edge = edges.get(3);
assertEquals(1, edge.flow);
assertEquals(0, edge.residualFlow);
edges = graph.edges(2);
assertEquals(2, edges.size());
// No flow from 2 to 3
edge = edges.get(3);
assertEquals(0, edge.flow);
assertEquals(1, edge.residualFlow);
// Flow from 2 to 1 now
edge = edges.get(1);
assertEquals(1, edge.flow);
assertEquals(0, edge.residualFlow);
}
@Test
public void testMinCostDetectNodeNotInNegativeCycle() {
final Graph<Integer> graph1 = new Graph<>();
graph1.addEdge(5, 0, 1, 0, 1);
graph1.addEdge(5, 1, 1, 0, 1);
graph1.addEdge(0, 2, 1, 1, 0);
graph1.addEdge(0, 3, 1, 1, 0);
graph1.addEdge(0, 4, 1, 10, 1);
graph1.addEdge(1, 2, 1, 1, 0);
graph1.addEdge(1, 3, 1, 10, 1);
graph1.addEdge(1, 4, 1, 1, 0);
graph1.addEdge(2, 6, 0, 0, 0);
graph1.addEdge(3, 6, 1, 0, 1);
graph1.addEdge(4, 6, 1, 0, 1);
graph1.setSourceNode(5);
graph1.setSinkNode(6);
assertEquals(20, graph1.totalCost());
// In this graph, the node we found for negative cycle is 2. However 2 isn't in the negative
// cycle itself. Negative cycle is 1 -> 4 -> 0 -> 3 -> 1
graph1.solveMinCostFlow();
assertEquals(2, graph1.totalCost());
Map<Integer, Graph<Integer>.Edge> edges = graph1.edges(5);
assertEquals(getEdge(0, 1, 0, 0, 1), edges.get(0));
assertEquals(getEdge(1, 1, 0, 0, 1), edges.get(1));
edges = graph1.edges(0);
assertEquals(getEdge(2, 1, 1, 1, 0), edges.get(2));
assertEquals(getEdge(3, 1, 1, 0, 1), edges.get(3));
assertEquals(getEdge(4, 1, 10, 1, 0), edges.get(4));
edges = graph1.edges(1);
assertEquals(getEdge(2, 1, 1, 1, 0), edges.get(2));
assertEquals(getEdge(3, 1, 10, 1, 0), edges.get(3));
assertEquals(getEdge(4, 1, 1, 0, 1), edges.get(4));
edges = graph1.edges(2);
assertEquals(getEdge(6, 0, 0, 0, 0), edges.get(6));
edges = graph1.edges(3);
assertEquals(getEdge(6, 1, 0, 0, 1), edges.get(6));
edges = graph1.edges(4);
assertEquals(getEdge(6, 1, 0, 0, 1), edges.get(6));
}
@Test
public void testDeterministic() {
final List<TestEdge> edgeList = new ArrayList<>();
edgeList.add(new TestEdge(0, 1, 1, 2, 1));
edgeList.add(new TestEdge(0, 2, 1, 1, 0));
edgeList.add(new TestEdge(0, 3, 1, 1, 0));
edgeList.add(new TestEdge(0, 4, 1, 1, 0));
edgeList.add(new TestEdge(1, 5, 1, 1, 1));
edgeList.add(new TestEdge(2, 5, 1, 1, 0));
edgeList.add(new TestEdge(3, 5, 1, 1, 0));
edgeList.add(new TestEdge(4, 5, 1, 1, 0));
// Test no matter the order of adding edges, min cost flow flows from 0 to 2 and then from 2 to 5
for (int i = 0; i < 10; i++) {
Collections.shuffle(edgeList);
final Graph<Integer> graph1 = new Graph<>();
for (final TestEdge edge : edgeList) {
graph1.addEdge(edge.source, edge.destination, edge.capacity, edge.cost, edge.flow);
}
graph1.setSourceNode(0);
graph1.setSinkNode(5);
assertEquals(3, graph1.totalCost());
graph1.solveMinCostFlow();
assertEquals(2, graph1.totalCost());
Map<Integer, Graph<Integer>.Edge> edges = graph1.edges(0);
assertEquals(4, edges.size());
assertEquals(getEdge(1, 1, 2, 1, 0), edges.get(1));
assertEquals(getEdge(2, 1, 1, 0, 1), edges.get(2));
assertEquals(getEdge(3, 1, 1, 1, 0), edges.get(3));
assertEquals(getEdge(4, 1, 1, 1, 0), edges.get(4));
edges = graph1.edges(1);
assertEquals(1, edges.size());
assertEquals(getEdge(5, 1, 1, 1, 0), edges.get(5));
edges = graph1.edges(2);
assertEquals(1, edges.size());
assertEquals(getEdge(5, 1, 1, 0, 1), edges.get(5));
edges = graph1.edges(3);
assertEquals(1, edges.size());
assertEquals(getEdge(5, 1, 1, 1, 0), edges.get(5));
edges = graph1.edges(4);
assertEquals(1, edges.size());
assertEquals(getEdge(5, 1, 1, 1, 0), edges.get(5));
}
}
private static Graph<Integer>.Edge getEdge(final int destination, final int capacity, final int cost, final int residualFlow, final int flow) {
return getEdge(destination, capacity, cost, residualFlow, flow, true);
}
private static Graph<Integer>.Edge getEdge(final int destination, final int capacity, final int cost, final int residualFlow, final int flow, final boolean forwardEdge) {
return new Graph<Integer>().new Edge(destination, capacity, cost, residualFlow, flow, forwardEdge);
}
private static class TestEdge {
final int source;
final int destination;
final int capacity;
final int cost;
final int flow;
TestEdge(final int source, final int destination, final int capacity, final int cost, final int flow) {
this.source = source;
this.destination = destination;
this.capacity = capacity;
this.cost = cost;
this.flow = flow;
}
}
}

View File

@ -40,7 +40,6 @@ import org.apache.kafka.common.utils.MockTime;
import org.apache.kafka.streams.StreamsConfig; import org.apache.kafka.streams.StreamsConfig;
import org.apache.kafka.streams.StreamsConfig.InternalConfig; import org.apache.kafka.streams.StreamsConfig.InternalConfig;
import org.apache.kafka.streams.processor.TaskId; import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.internals.RackAwareTaskAssignor;
import org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology; import org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
import org.apache.kafka.test.MockClientSupplier; import org.apache.kafka.test.MockClientSupplier;
import org.apache.kafka.test.MockInternalTopicManager; import org.apache.kafka.test.MockInternalTopicManager;