mirror of https://github.com/apache/kafka.git
KAFKA-15022: [2/N] introduce graph to compute min cost (#13996)
Part of KIP-925. Reviewers: Matthias J. Sax <matthias@confluent.io>
This commit is contained in:
parent
125dbb9286
commit
6bb88ae2f7
|
@ -0,0 +1,377 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.kafka.streams.processor.internals.assignment;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.SortedMap;
|
||||
import java.util.SortedSet;
|
||||
import java.util.TreeMap;
|
||||
import java.util.TreeSet;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class Graph<V extends Comparable<V>> {
|
||||
public class Edge {
|
||||
final V destination;
|
||||
final int capacity;
|
||||
final int cost;
|
||||
int residualFlow;
|
||||
int flow;
|
||||
Edge counterEdge;
|
||||
boolean forwardEdge;
|
||||
|
||||
public Edge(final V destination, final int capacity, final int cost, final int residualFlow, final int flow) {
|
||||
this(destination, capacity, cost, residualFlow, flow, true);
|
||||
}
|
||||
|
||||
public Edge(final V destination, final int capacity, final int cost, final int residualFlow, final int flow,
|
||||
final boolean forwardEdge) {
|
||||
Objects.requireNonNull(destination);
|
||||
if (capacity < 0) {
|
||||
throw new IllegalArgumentException("Edge capacity cannot be negative");
|
||||
}
|
||||
if (flow > capacity) {
|
||||
throw new IllegalArgumentException(String.format("Edge flow %d cannot exceed capacity %d",
|
||||
flow, capacity));
|
||||
}
|
||||
|
||||
this.destination = destination;
|
||||
this.capacity = capacity;
|
||||
this.cost = cost;
|
||||
this.residualFlow = residualFlow;
|
||||
this.flow = flow;
|
||||
this.forwardEdge = forwardEdge;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(final Object other) {
|
||||
if (this == other) {
|
||||
return true;
|
||||
}
|
||||
if (other == null || other.getClass() != getClass()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
final Graph<?>.Edge otherEdge = (Graph<?>.Edge) other;
|
||||
|
||||
return destination.equals(otherEdge.destination) && capacity == otherEdge.capacity
|
||||
&& cost == otherEdge.cost && residualFlow == otherEdge.residualFlow && flow == otherEdge.flow
|
||||
&& forwardEdge == otherEdge.forwardEdge;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(destination, capacity, cost, residualFlow, flow, forwardEdge);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "{destination= " + destination + ", capacity=" + capacity + ", cost=" + cost
|
||||
+ ", residualFlow=" + residualFlow + ", flow=" + flow + ", forwardEdge=" + forwardEdge;
|
||||
}
|
||||
}
|
||||
|
||||
private final SortedMap<V, SortedMap<V, Edge>> adjList = new TreeMap<>();
|
||||
private final SortedSet<V> nodes = new TreeSet<>();
|
||||
private final boolean isResidualGraph;
|
||||
private V sourceNode, sinkNode;
|
||||
|
||||
public Graph() {
|
||||
this(false);
|
||||
}
|
||||
|
||||
private Graph(final boolean isResidualGraph) {
|
||||
this.isResidualGraph = isResidualGraph;
|
||||
}
|
||||
|
||||
public void addEdge(final V u, final V v, final int capacity, final int cost, final int flow) {
|
||||
addEdge(u, new Edge(v, capacity, cost, capacity - flow, flow));
|
||||
}
|
||||
|
||||
public Set<V> nodes() {
|
||||
return nodes;
|
||||
}
|
||||
|
||||
public Map<V, Edge> edges(final V node) {
|
||||
return adjList.get(node);
|
||||
}
|
||||
|
||||
public boolean isResidualGraph() {
|
||||
return isResidualGraph;
|
||||
}
|
||||
|
||||
public void setSourceNode(final V node) {
|
||||
sourceNode = node;
|
||||
}
|
||||
|
||||
public void setSinkNode(final V node) {
|
||||
sinkNode = node;
|
||||
}
|
||||
|
||||
public int totalCost() {
|
||||
int totalCost = 0;
|
||||
for (final Map.Entry<V, SortedMap<V, Edge>> nodeEdges : adjList.entrySet()) {
|
||||
final SortedMap<V, Edge> edges = nodeEdges.getValue();
|
||||
for (final Entry<V, Edge> nodeEdge : edges.entrySet()) {
|
||||
totalCost += nodeEdge.getValue().cost * nodeEdge.getValue().flow;
|
||||
}
|
||||
}
|
||||
return totalCost;
|
||||
}
|
||||
|
||||
private void addEdge(final V u, final Edge edge) {
|
||||
if (!isResidualGraph) {
|
||||
// Check if there's already an edge from u to v
|
||||
final Map<V, Edge> edgeMap = adjList.get(edge.destination);
|
||||
if (edgeMap != null && edgeMap.containsKey(u)) {
|
||||
throw new IllegalArgumentException(
|
||||
"There is already an edge from " + edge.destination
|
||||
+ " to " + u + ". Can not add an edge from " + u + " to " + edge.destination
|
||||
+ " since there will create a cycle between two nodes");
|
||||
}
|
||||
}
|
||||
|
||||
adjList.computeIfAbsent(u, set -> new TreeMap<>()).put(edge.destination, edge);
|
||||
nodes.add(u);
|
||||
nodes.add(edge.destination);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get residual graph of this graph.
|
||||
* Residual graph definition:
|
||||
* If there is an edge in original graph from u to v with capacity c, cost w and flow f,
|
||||
* then in the new graph there are two edges e1 and e2. e1 is from u to v with capacity c - f,
|
||||
* cost w and flow f. e2 is from v to u with capacity f, cost -w and flow 0.
|
||||
*
|
||||
* @return Residual graph
|
||||
*/
|
||||
public Graph<V> residualGraph() {
|
||||
if (isResidualGraph) {
|
||||
return this;
|
||||
}
|
||||
|
||||
final Graph<V> residualGraph = new Graph<>(true);
|
||||
for (final Map.Entry<V, SortedMap<V, Edge>> nodeEdges : adjList.entrySet()) {
|
||||
final V node = nodeEdges.getKey();
|
||||
final SortedMap<V, Edge> edges = nodeEdges.getValue();
|
||||
for (final Entry<V, Edge> nodeEdge : edges.entrySet()) {
|
||||
final Edge edge = nodeEdge.getValue();
|
||||
final Edge forwardEdge = new Edge(edge.destination, edge.capacity, edge.cost, edge.capacity - edge.flow, edge.flow);
|
||||
final Edge backwardEdge = new Edge(node, edge.capacity, edge.cost * -1, edge.flow, 0, false);
|
||||
forwardEdge.counterEdge = backwardEdge;
|
||||
backwardEdge.counterEdge = forwardEdge;
|
||||
residualGraph.addEdge(node, forwardEdge);
|
||||
residualGraph.addEdge(edge.destination, backwardEdge);
|
||||
}
|
||||
}
|
||||
return residualGraph;
|
||||
}
|
||||
|
||||
/**
|
||||
* Solve min cost flow with cycle canceling algorithm.
|
||||
*/
|
||||
public void solveMinCostFlow() {
|
||||
validateMinCostGraph();
|
||||
final Graph<V> residualGraph = residualGraph();
|
||||
residualGraph.cancelNegativeCycles();
|
||||
|
||||
for (final Entry<V, SortedMap<V, Edge>> nodeEdges : adjList.entrySet()) {
|
||||
final V node = nodeEdges.getKey();
|
||||
for (final Entry<V, Edge> nodeEdge : nodeEdges.getValue().entrySet()) {
|
||||
final V destination = nodeEdge.getKey();
|
||||
final Edge edge = nodeEdge.getValue();
|
||||
final Edge residualEdge = residualGraph.adjList.get(node).get(destination);
|
||||
edge.flow = residualEdge.flow;
|
||||
edge.residualFlow = residualEdge.residualFlow;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void populateInOutFlow(final Map<V, Long> inFlow, final Map<V, Long> outFlow) {
|
||||
for (final Entry<V, SortedMap<V, Edge>> nodeEdges : adjList.entrySet()) {
|
||||
final V node = nodeEdges.getKey();
|
||||
if (node.equals(sinkNode)) {
|
||||
throw new IllegalStateException("Sink node " + sinkNode + " shouldn't have output");
|
||||
}
|
||||
for (final Entry<V, Edge> nodeEdge : nodeEdges.getValue().entrySet()) {
|
||||
final V destination = nodeEdge.getKey();
|
||||
if (destination.equals(sourceNode)) {
|
||||
throw new IllegalStateException("Source node " + sourceNode + " shouldn't have input " + node);
|
||||
}
|
||||
final Edge edge = nodeEdge.getValue();
|
||||
Long count = outFlow.get(node);
|
||||
if (count == null) {
|
||||
outFlow.put(node, (long) edge.flow);
|
||||
} else {
|
||||
outFlow.put(node, count + edge.flow);
|
||||
}
|
||||
|
||||
count = inFlow.get(destination);
|
||||
if (count == null) {
|
||||
inFlow.put(destination, (long) edge.flow);
|
||||
} else {
|
||||
inFlow.put(destination, count + edge.flow);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void validateMinCostGraph() {
|
||||
if (isResidualGraph) {
|
||||
throw new IllegalStateException("Should not be residual graph to solve min cost flow");
|
||||
}
|
||||
|
||||
/*
|
||||
Check provided flow satisfying below constraints:
|
||||
1. Input flow and output flow for each node should be the same except for source and destination node
|
||||
2. Output flow of source and input flow of destination should be the same
|
||||
*/
|
||||
|
||||
final Map<V, Long> inFlow = new HashMap<>();
|
||||
final Map<V, Long> outFlow = new HashMap<>();
|
||||
populateInOutFlow(inFlow, outFlow);
|
||||
|
||||
for (final Entry<V, Long> in : inFlow.entrySet()) {
|
||||
if (in.getKey().equals(sourceNode) || in.getKey().equals(sinkNode)) {
|
||||
continue;
|
||||
}
|
||||
final Long out = outFlow.get(in.getKey());
|
||||
if (!Objects.equals(in.getValue(), out)) {
|
||||
throw new IllegalStateException("Input flow for node " + in.getKey() + " is " +
|
||||
in.getValue() + " which doesn't match output flow " + out);
|
||||
}
|
||||
}
|
||||
|
||||
final Long sourceOutput = outFlow.get(sourceNode);
|
||||
final Long sinkInput = inFlow.get(sinkNode);
|
||||
if (!Objects.equals(sourceOutput, sinkInput)) {
|
||||
throw new IllegalStateException("Output flow for source " + sourceNode + " is " + sourceOutput
|
||||
+ " which doesn't match input flow " + sinkInput + " for sink " + sinkNode);
|
||||
}
|
||||
}
|
||||
|
||||
private void cancelNegativeCycles() {
|
||||
if (!isResidualGraph) {
|
||||
throw new IllegalStateException("Should be residual graph to cancel negative cycles");
|
||||
}
|
||||
boolean cyclePossible = true;
|
||||
while (cyclePossible) {
|
||||
cyclePossible = false;
|
||||
for (final V node : nodes) {
|
||||
final Map<V, V> parentNodes = new HashMap<>();
|
||||
final Map<V, Edge> parentEdges = new HashMap<>();
|
||||
final V possibleNodeInCycle = detectNegativeCycles(node, parentNodes, parentEdges);
|
||||
|
||||
if (possibleNodeInCycle == null) {
|
||||
continue;
|
||||
}
|
||||
|
||||
final Set<V> visited = new HashSet<>();
|
||||
V nodeInCycle = possibleNodeInCycle;
|
||||
while (!visited.contains(nodeInCycle)) {
|
||||
visited.add(nodeInCycle);
|
||||
nodeInCycle = parentNodes.get(nodeInCycle);
|
||||
}
|
||||
|
||||
cyclePossible = true;
|
||||
cancelNegativeCycle(nodeInCycle, parentNodes, parentEdges);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void cancelNegativeCycle(final V nodeInCycle, final Map<V, V> parentNodes, final Map<V, Edge> parentEdges) {
|
||||
// Start from parentNode since nodeInCyle is used as exit condition in below loops
|
||||
final V parentNode = parentNodes.get(nodeInCycle);
|
||||
Edge parentEdge = parentEdges.get(nodeInCycle);
|
||||
|
||||
// Find max possible negative flow
|
||||
int possibleFlow = parentEdge.residualFlow;
|
||||
for (V curNode = parentNode; curNode != nodeInCycle; curNode = parentNodes.get(curNode)) {
|
||||
parentEdge = parentEdges.get(curNode);
|
||||
possibleFlow = Math.min(possibleFlow, parentEdge.residualFlow);
|
||||
}
|
||||
|
||||
// Update graph by removing negative flow
|
||||
parentEdge = parentEdges.get(nodeInCycle);
|
||||
Edge counterEdge = parentEdge.counterEdge;
|
||||
parentEdge.residualFlow -= possibleFlow;
|
||||
if (parentEdge.forwardEdge) {
|
||||
parentEdge.flow += possibleFlow;
|
||||
}
|
||||
counterEdge.residualFlow += possibleFlow;
|
||||
if (counterEdge.forwardEdge && counterEdge.flow >= possibleFlow) {
|
||||
counterEdge.flow -= possibleFlow;
|
||||
}
|
||||
for (V curNode = parentNode; curNode != nodeInCycle; curNode = parentNodes.get(curNode)) {
|
||||
parentEdge = parentEdges.get(curNode);
|
||||
counterEdge = parentEdge.counterEdge;
|
||||
parentEdge.residualFlow -= possibleFlow;
|
||||
if (parentEdge.forwardEdge) {
|
||||
parentEdge.flow += possibleFlow;
|
||||
}
|
||||
counterEdge.residualFlow += possibleFlow;
|
||||
if (counterEdge.forwardEdge && counterEdge.flow >= possibleFlow) {
|
||||
counterEdge.flow -= possibleFlow;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Detect negative cycle using Bellman-ford shortest path algorithm.
|
||||
* @param source Source node
|
||||
* @param parentNodes Parent nodes to store negative cycle nodes
|
||||
* @param parentEdges Parent edges to store negative cycle edges
|
||||
*
|
||||
* @return One node which can lead to negative cycle if exists or null if there's no negative cycle
|
||||
*/
|
||||
V detectNegativeCycles(final V source, final Map<V, V> parentNodes, final Map<V, Edge> parentEdges) {
|
||||
// Use long to account for any overflow
|
||||
final Map<V, Long> distance = nodes.stream().collect(Collectors.toMap(node -> node, node -> (long) Integer.MAX_VALUE));
|
||||
distance.put(source, 0L);
|
||||
final int nodeCount = nodes.size();
|
||||
|
||||
// Iterate nodeCount iterations since Bellaman-Ford will find shortest path in nodeCount - 1
|
||||
// iterations. If the distance can still be relaxed in nodeCount iteration, there's a negative
|
||||
// cycle
|
||||
for (int i = 0; i < nodeCount; i++) {
|
||||
// Iterate through all edges
|
||||
for (final Entry<V, SortedMap<V, Edge>> nodeEdges : adjList.entrySet()) {
|
||||
final V u = nodeEdges.getKey();
|
||||
for (final Entry<V, Edge> nodeEdge : nodeEdges.getValue().entrySet()) {
|
||||
final Edge edge = nodeEdge.getValue();
|
||||
if (edge.residualFlow == 0) {
|
||||
continue;
|
||||
}
|
||||
final V v = edge.destination;
|
||||
if (distance.get(v) > distance.get(u) + edge.cost) {
|
||||
if (i == nodeCount - 1) {
|
||||
return v;
|
||||
}
|
||||
distance.put(v, distance.get(u) + edge.cost);
|
||||
parentNodes.put(v, u);
|
||||
parentEdges.put(v, edge);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
|
@ -14,7 +14,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.kafka.streams.processor.internals;
|
||||
package org.apache.kafka.streams.processor.internals.assignment;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
|
@ -30,6 +30,7 @@ import org.apache.kafka.common.TopicPartition;
|
|||
import org.apache.kafka.common.TopicPartitionInfo;
|
||||
import org.apache.kafka.streams.KeyValue;
|
||||
import org.apache.kafka.streams.processor.TaskId;
|
||||
import org.apache.kafka.streams.processor.internals.InternalTopicManager;
|
||||
import org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
|
||||
import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs;
|
||||
import org.slf4j.Logger;
|
|
@ -0,0 +1,414 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.kafka.streams.processor.internals.assignment;
|
||||
|
||||
import static org.hamcrest.MatcherAssert.assertThat;
|
||||
import static org.hamcrest.Matchers.contains;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertSame;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
public class GraphTest {
|
||||
private Graph<Integer> graph;
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
/*
|
||||
* Node 0 and 2 are both connected to node 1 and 3. There's a flow of 1 unit from 0 to 1 and 2 to
|
||||
* 3. The total cost in this case is 5. Min cost should be 2 by flowing 1 unit from 0 to 3 and 2
|
||||
* to 1
|
||||
*/
|
||||
graph = new Graph<>();
|
||||
graph.addEdge(0, 1, 1, 3, 1);
|
||||
graph.addEdge(0, 3, 1, 1, 0);
|
||||
graph.addEdge(2, 1, 1, 1, 0);
|
||||
graph.addEdge(2, 3, 1, 2, 1);
|
||||
graph.addEdge(4, 0, 1, 0, 1);
|
||||
graph.addEdge(4, 2, 1, 0, 1);
|
||||
graph.addEdge(1, 5, 1, 0, 1);
|
||||
graph.addEdge(3, 5, 1, 0, 1);
|
||||
graph.setSourceNode(4);
|
||||
graph.setSinkNode(5);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBasic() {
|
||||
final Set<Integer> nodes = graph.nodes();
|
||||
assertEquals(6, nodes.size());
|
||||
assertThat(nodes, contains(0, 1, 2, 3, 4, 5));
|
||||
|
||||
Map<Integer, Graph<Integer>.Edge> edges = graph.edges(0);
|
||||
assertEquals(2, edges.size());
|
||||
assertEquals(getEdge(1, 1, 3, 0, 1), edges.get(1));
|
||||
assertEquals(getEdge(3, 1, 1, 1, 0), edges.get(3));
|
||||
|
||||
edges = graph.edges(2);
|
||||
assertEquals(2, edges.size());
|
||||
assertEquals(getEdge(1, 1, 1, 1, 0), edges.get(1));
|
||||
assertEquals(getEdge(3, 1, 2, 0, 1), edges.get(3));
|
||||
|
||||
edges = graph.edges(1);
|
||||
assertEquals(1, edges.size());
|
||||
assertEquals(getEdge(5, 1, 0, 0, 1), edges.get(5));
|
||||
|
||||
edges = graph.edges(3);
|
||||
assertEquals(1, edges.size());
|
||||
assertEquals(getEdge(5, 1, 0, 0, 1), edges.get(5));
|
||||
|
||||
edges = graph.edges(4);
|
||||
assertEquals(2, edges.size());
|
||||
assertEquals(getEdge(0, 1, 0, 0, 1), edges.get(0));
|
||||
assertEquals(getEdge(2, 1, 0, 0, 1), edges.get(2));
|
||||
|
||||
edges = graph.edges(5);
|
||||
assertNull(edges);
|
||||
|
||||
assertFalse(graph.isResidualGraph());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testResidualGraph() {
|
||||
final Graph<Integer> residualGraph = graph.residualGraph();
|
||||
final Graph<Integer> residualGraph1 = residualGraph.residualGraph();
|
||||
assertSame(residualGraph1, residualGraph);
|
||||
|
||||
final Set<Integer> nodes = residualGraph.nodes();
|
||||
assertEquals(6, nodes.size());
|
||||
assertThat(nodes, contains(0, 1, 2, 3, 4, 5));
|
||||
|
||||
Map<Integer, Graph<Integer>.Edge> edges = residualGraph.edges(0);
|
||||
assertEquals(3, edges.size());
|
||||
assertEquals(getEdge(1, 1, 3, 0, 1), edges.get(1));
|
||||
assertEquals(getEdge(3, 1, 1, 1, 0), edges.get(3));
|
||||
assertEquals(getEdge(4, 1, 0, 1, 0, false), edges.get(4));
|
||||
|
||||
edges = residualGraph.edges(2);
|
||||
assertEquals(3, edges.size());
|
||||
assertEquals(getEdge(1, 1, 1, 1, 0), edges.get(1));
|
||||
assertEquals(getEdge(3, 1, 2, 0, 1), edges.get(3));
|
||||
assertEquals(getEdge(4, 1, 0, 1, 0, false), edges.get(4));
|
||||
|
||||
edges = residualGraph.edges(1);
|
||||
assertEquals(3, edges.size());
|
||||
assertEquals(getEdge(0, 1, -3, 1, 0, false), edges.get(0));
|
||||
assertEquals(getEdge(2, 1, -1, 0, 0, false), edges.get(2));
|
||||
assertEquals(getEdge(5, 1, 0, 0, 1), edges.get(5));
|
||||
|
||||
edges = residualGraph.edges(3);
|
||||
assertEquals(3, edges.size());
|
||||
assertEquals(getEdge(0, 1, -1, 0, 0, false), edges.get(0));
|
||||
assertEquals(getEdge(2, 1, -2, 1, 0, false), edges.get(2));
|
||||
assertEquals(getEdge(5, 1, 0, 0, 1), edges.get(5));
|
||||
|
||||
assertTrue(residualGraph.isResidualGraph());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInvalidOperation() {
|
||||
final Graph<Integer> graph1 = new Graph<>();
|
||||
Exception exception = assertThrows(IllegalArgumentException.class, () -> graph1.addEdge(0, 1, -1, 0, 0));
|
||||
assertEquals("Edge capacity cannot be negative", exception.getMessage());
|
||||
|
||||
exception = assertThrows(IllegalArgumentException.class, () -> graph1.addEdge(0, 1, 1, 0, 2));
|
||||
assertEquals("Edge flow 2 cannot exceed capacity 1", exception.getMessage());
|
||||
|
||||
graph1.addEdge(0, 1, 1, 1, 1);
|
||||
exception = assertThrows(IllegalArgumentException.class, () -> graph1.addEdge(1, 0, 1, 0, 0));
|
||||
assertEquals("There is already an edge from 0 to 1. Can not add an edge from 1 to 0 since "
|
||||
+ "there will create a cycle between two nodes", exception.getMessage());
|
||||
|
||||
final Graph<Integer> residualGraph = graph1.residualGraph();
|
||||
exception = assertThrows(IllegalStateException.class, residualGraph::solveMinCostFlow);
|
||||
assertEquals("Should not be residual graph to solve min cost flow", exception.getMessage());
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInvalidSource() {
|
||||
final Graph<Integer> graph1 = new Graph<>();
|
||||
graph1.addEdge(0, 1, 1, 1, 0);
|
||||
graph1.addEdge(1, 2, 1, 1, 0);
|
||||
graph1.setSourceNode(1);
|
||||
graph1.setSinkNode(2);
|
||||
final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow);
|
||||
assertEquals("Source node 1 shouldn't have input 0", exception.getMessage());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInvalidSink() {
|
||||
final Graph<Integer> graph1 = new Graph<>();
|
||||
graph1.addEdge(0, 1, 1, 1, 0);
|
||||
graph1.addEdge(1, 2, 1, 1, 0);
|
||||
graph1.setSourceNode(0);
|
||||
graph1.setSinkNode(1);
|
||||
final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow);
|
||||
assertEquals("Sink node 1 shouldn't have output", exception.getMessage());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInvalidFlow() {
|
||||
final Graph<Integer> graph1 = new Graph<>();
|
||||
graph1.addEdge(0, 1, 1, 1, 1);
|
||||
graph1.addEdge(0, 2, 2, 1, 2);
|
||||
graph1.addEdge(1, 3, 1, 1, 1);
|
||||
graph1.addEdge(2, 3, 2, 1, 0); // Missing flow from 2 to 3
|
||||
graph1.setSourceNode(0);
|
||||
graph1.setSinkNode(3);
|
||||
final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow);
|
||||
assertEquals("Input flow for node 2 is 2 which doesn't match output flow 0", exception.getMessage());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMissingSource() {
|
||||
final Graph<Integer> graph1 = new Graph<>();
|
||||
graph1.addEdge(0, 1, 1, 1, 1);
|
||||
graph1.addEdge(0, 2, 2, 1, 2);
|
||||
graph1.addEdge(1, 3, 1, 1, 1);
|
||||
graph1.addEdge(2, 3, 2, 1, 2);
|
||||
graph1.setSinkNode(3);
|
||||
final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow);
|
||||
assertEquals("Output flow for source null is null which doesn't match input flow 3 for sink 3",
|
||||
exception.getMessage());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDisconnectedGraph() {
|
||||
final Graph<Integer> graph1 = new Graph<>();
|
||||
graph1.addEdge(0, 1, 1, 1, 1);
|
||||
graph1.addEdge(2, 3, 2, 1, 2);
|
||||
graph1.setSourceNode(0);
|
||||
graph1.setSinkNode(1);
|
||||
final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow);
|
||||
assertEquals("Input flow for node 3 is 2 which doesn't match output flow null",
|
||||
exception.getMessage());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDisconnectedGraphCrossSourceSink() {
|
||||
final Graph<Integer> graph1 = new Graph<>();
|
||||
graph1.addEdge(0, 1, 1, 1, 1);
|
||||
graph1.addEdge(2, 3, 2, 1, 2);
|
||||
graph1.setSourceNode(0);
|
||||
graph1.setSinkNode(3);
|
||||
final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow);
|
||||
assertEquals("Input flow for node 1 is 1 which doesn't match output flow null",
|
||||
exception.getMessage());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testJustSourceSink() {
|
||||
final Graph<Integer> graph1 = new Graph<>();
|
||||
graph1.addEdge(0, 1, 1, 1, 1);
|
||||
graph1.setSourceNode(0);
|
||||
graph1.setSinkNode(1);
|
||||
graph1.solveMinCostFlow();
|
||||
assertEquals(1, graph1.totalCost());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMinCostFlow() {
|
||||
// Original graph, flow from 0 to 1 and 2 to 3
|
||||
Map<Integer, Graph<Integer>.Edge> edges = graph.edges(0);
|
||||
Graph<Integer>.Edge edge = edges.get(1);
|
||||
assertEquals(1, edge.flow);
|
||||
assertEquals(0, edge.residualFlow);
|
||||
|
||||
edge = edges.get(3);
|
||||
assertEquals(0, edge.flow);
|
||||
assertEquals(1, edge.residualFlow);
|
||||
|
||||
edges = graph.edges(2);
|
||||
edge = edges.get(3);
|
||||
assertEquals(1, edge.flow);
|
||||
assertEquals(0, edge.residualFlow);
|
||||
|
||||
edge = edges.get(1);
|
||||
assertEquals(0, edge.flow);
|
||||
assertEquals(1, edge.residualFlow);
|
||||
|
||||
assertEquals(5, graph.totalCost());
|
||||
|
||||
graph.solveMinCostFlow();
|
||||
|
||||
assertEquals(2, graph.totalCost());
|
||||
|
||||
edges = graph.edges(0);
|
||||
assertEquals(2, edges.size());
|
||||
|
||||
// No flow from 0 to 1
|
||||
edge = edges.get(1);
|
||||
assertEquals(0, edge.flow);
|
||||
assertEquals(1, edge.residualFlow);
|
||||
|
||||
// Flow from 0 to 3 now
|
||||
edge = edges.get(3);
|
||||
assertEquals(1, edge.flow);
|
||||
assertEquals(0, edge.residualFlow);
|
||||
|
||||
edges = graph.edges(2);
|
||||
assertEquals(2, edges.size());
|
||||
|
||||
// No flow from 2 to 3
|
||||
edge = edges.get(3);
|
||||
assertEquals(0, edge.flow);
|
||||
assertEquals(1, edge.residualFlow);
|
||||
|
||||
// Flow from 2 to 1 now
|
||||
edge = edges.get(1);
|
||||
assertEquals(1, edge.flow);
|
||||
assertEquals(0, edge.residualFlow);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMinCostDetectNodeNotInNegativeCycle() {
|
||||
final Graph<Integer> graph1 = new Graph<>();
|
||||
|
||||
graph1.addEdge(5, 0, 1, 0, 1);
|
||||
graph1.addEdge(5, 1, 1, 0, 1);
|
||||
|
||||
graph1.addEdge(0, 2, 1, 1, 0);
|
||||
graph1.addEdge(0, 3, 1, 1, 0);
|
||||
graph1.addEdge(0, 4, 1, 10, 1);
|
||||
|
||||
graph1.addEdge(1, 2, 1, 1, 0);
|
||||
graph1.addEdge(1, 3, 1, 10, 1);
|
||||
graph1.addEdge(1, 4, 1, 1, 0);
|
||||
|
||||
graph1.addEdge(2, 6, 0, 0, 0);
|
||||
graph1.addEdge(3, 6, 1, 0, 1);
|
||||
graph1.addEdge(4, 6, 1, 0, 1);
|
||||
|
||||
graph1.setSourceNode(5);
|
||||
graph1.setSinkNode(6);
|
||||
|
||||
assertEquals(20, graph1.totalCost());
|
||||
|
||||
// In this graph, the node we found for negative cycle is 2. However 2 isn't in the negative
|
||||
// cycle itself. Negative cycle is 1 -> 4 -> 0 -> 3 -> 1
|
||||
graph1.solveMinCostFlow();
|
||||
assertEquals(2, graph1.totalCost());
|
||||
|
||||
Map<Integer, Graph<Integer>.Edge> edges = graph1.edges(5);
|
||||
assertEquals(getEdge(0, 1, 0, 0, 1), edges.get(0));
|
||||
assertEquals(getEdge(1, 1, 0, 0, 1), edges.get(1));
|
||||
|
||||
edges = graph1.edges(0);
|
||||
assertEquals(getEdge(2, 1, 1, 1, 0), edges.get(2));
|
||||
assertEquals(getEdge(3, 1, 1, 0, 1), edges.get(3));
|
||||
assertEquals(getEdge(4, 1, 10, 1, 0), edges.get(4));
|
||||
|
||||
edges = graph1.edges(1);
|
||||
assertEquals(getEdge(2, 1, 1, 1, 0), edges.get(2));
|
||||
assertEquals(getEdge(3, 1, 10, 1, 0), edges.get(3));
|
||||
assertEquals(getEdge(4, 1, 1, 0, 1), edges.get(4));
|
||||
|
||||
edges = graph1.edges(2);
|
||||
assertEquals(getEdge(6, 0, 0, 0, 0), edges.get(6));
|
||||
|
||||
edges = graph1.edges(3);
|
||||
assertEquals(getEdge(6, 1, 0, 0, 1), edges.get(6));
|
||||
|
||||
edges = graph1.edges(4);
|
||||
assertEquals(getEdge(6, 1, 0, 0, 1), edges.get(6));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDeterministic() {
|
||||
final List<TestEdge> edgeList = new ArrayList<>();
|
||||
edgeList.add(new TestEdge(0, 1, 1, 2, 1));
|
||||
edgeList.add(new TestEdge(0, 2, 1, 1, 0));
|
||||
edgeList.add(new TestEdge(0, 3, 1, 1, 0));
|
||||
edgeList.add(new TestEdge(0, 4, 1, 1, 0));
|
||||
edgeList.add(new TestEdge(1, 5, 1, 1, 1));
|
||||
edgeList.add(new TestEdge(2, 5, 1, 1, 0));
|
||||
edgeList.add(new TestEdge(3, 5, 1, 1, 0));
|
||||
edgeList.add(new TestEdge(4, 5, 1, 1, 0));
|
||||
|
||||
// Test no matter the order of adding edges, min cost flow flows from 0 to 2 and then from 2 to 5
|
||||
for (int i = 0; i < 10; i++) {
|
||||
Collections.shuffle(edgeList);
|
||||
final Graph<Integer> graph1 = new Graph<>();
|
||||
for (final TestEdge edge : edgeList) {
|
||||
graph1.addEdge(edge.source, edge.destination, edge.capacity, edge.cost, edge.flow);
|
||||
}
|
||||
graph1.setSourceNode(0);
|
||||
graph1.setSinkNode(5);
|
||||
assertEquals(3, graph1.totalCost());
|
||||
|
||||
graph1.solveMinCostFlow();
|
||||
assertEquals(2, graph1.totalCost());
|
||||
|
||||
Map<Integer, Graph<Integer>.Edge> edges = graph1.edges(0);
|
||||
assertEquals(4, edges.size());
|
||||
assertEquals(getEdge(1, 1, 2, 1, 0), edges.get(1));
|
||||
assertEquals(getEdge(2, 1, 1, 0, 1), edges.get(2));
|
||||
assertEquals(getEdge(3, 1, 1, 1, 0), edges.get(3));
|
||||
assertEquals(getEdge(4, 1, 1, 1, 0), edges.get(4));
|
||||
|
||||
edges = graph1.edges(1);
|
||||
assertEquals(1, edges.size());
|
||||
assertEquals(getEdge(5, 1, 1, 1, 0), edges.get(5));
|
||||
|
||||
edges = graph1.edges(2);
|
||||
assertEquals(1, edges.size());
|
||||
assertEquals(getEdge(5, 1, 1, 0, 1), edges.get(5));
|
||||
|
||||
edges = graph1.edges(3);
|
||||
assertEquals(1, edges.size());
|
||||
assertEquals(getEdge(5, 1, 1, 1, 0), edges.get(5));
|
||||
|
||||
edges = graph1.edges(4);
|
||||
assertEquals(1, edges.size());
|
||||
assertEquals(getEdge(5, 1, 1, 1, 0), edges.get(5));
|
||||
}
|
||||
}
|
||||
|
||||
private static Graph<Integer>.Edge getEdge(final int destination, final int capacity, final int cost, final int residualFlow, final int flow) {
|
||||
return getEdge(destination, capacity, cost, residualFlow, flow, true);
|
||||
}
|
||||
|
||||
private static Graph<Integer>.Edge getEdge(final int destination, final int capacity, final int cost, final int residualFlow, final int flow, final boolean forwardEdge) {
|
||||
return new Graph<Integer>().new Edge(destination, capacity, cost, residualFlow, flow, forwardEdge);
|
||||
}
|
||||
|
||||
private static class TestEdge {
|
||||
final int source;
|
||||
final int destination;
|
||||
final int capacity;
|
||||
final int cost;
|
||||
final int flow;
|
||||
|
||||
TestEdge(final int source, final int destination, final int capacity, final int cost, final int flow) {
|
||||
this.source = source;
|
||||
this.destination = destination;
|
||||
this.capacity = capacity;
|
||||
this.cost = cost;
|
||||
this.flow = flow;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -40,7 +40,6 @@ import org.apache.kafka.common.utils.MockTime;
|
|||
import org.apache.kafka.streams.StreamsConfig;
|
||||
import org.apache.kafka.streams.StreamsConfig.InternalConfig;
|
||||
import org.apache.kafka.streams.processor.TaskId;
|
||||
import org.apache.kafka.streams.processor.internals.RackAwareTaskAssignor;
|
||||
import org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
|
||||
import org.apache.kafka.test.MockClientSupplier;
|
||||
import org.apache.kafka.test.MockInternalTopicManager;
|
||||
|
|
Loading…
Reference in New Issue