mirror of https://github.com/apache/kafka.git
KAFKA-15022: [2/N] introduce graph to compute min cost (#13996)
Part of KIP-925. Reviewers: Matthias J. Sax <matthias@confluent.io>
This commit is contained in:
parent
125dbb9286
commit
6bb88ae2f7
|
@ -0,0 +1,377 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.kafka.streams.processor.internals.assignment;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Map.Entry;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.SortedMap;
|
||||||
|
import java.util.SortedSet;
|
||||||
|
import java.util.TreeMap;
|
||||||
|
import java.util.TreeSet;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
public class Graph<V extends Comparable<V>> {
|
||||||
|
public class Edge {
|
||||||
|
final V destination;
|
||||||
|
final int capacity;
|
||||||
|
final int cost;
|
||||||
|
int residualFlow;
|
||||||
|
int flow;
|
||||||
|
Edge counterEdge;
|
||||||
|
boolean forwardEdge;
|
||||||
|
|
||||||
|
public Edge(final V destination, final int capacity, final int cost, final int residualFlow, final int flow) {
|
||||||
|
this(destination, capacity, cost, residualFlow, flow, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Edge(final V destination, final int capacity, final int cost, final int residualFlow, final int flow,
|
||||||
|
final boolean forwardEdge) {
|
||||||
|
Objects.requireNonNull(destination);
|
||||||
|
if (capacity < 0) {
|
||||||
|
throw new IllegalArgumentException("Edge capacity cannot be negative");
|
||||||
|
}
|
||||||
|
if (flow > capacity) {
|
||||||
|
throw new IllegalArgumentException(String.format("Edge flow %d cannot exceed capacity %d",
|
||||||
|
flow, capacity));
|
||||||
|
}
|
||||||
|
|
||||||
|
this.destination = destination;
|
||||||
|
this.capacity = capacity;
|
||||||
|
this.cost = cost;
|
||||||
|
this.residualFlow = residualFlow;
|
||||||
|
this.flow = flow;
|
||||||
|
this.forwardEdge = forwardEdge;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(final Object other) {
|
||||||
|
if (this == other) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (other == null || other.getClass() != getClass()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
final Graph<?>.Edge otherEdge = (Graph<?>.Edge) other;
|
||||||
|
|
||||||
|
return destination.equals(otherEdge.destination) && capacity == otherEdge.capacity
|
||||||
|
&& cost == otherEdge.cost && residualFlow == otherEdge.residualFlow && flow == otherEdge.flow
|
||||||
|
&& forwardEdge == otherEdge.forwardEdge;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(destination, capacity, cost, residualFlow, flow, forwardEdge);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "{destination= " + destination + ", capacity=" + capacity + ", cost=" + cost
|
||||||
|
+ ", residualFlow=" + residualFlow + ", flow=" + flow + ", forwardEdge=" + forwardEdge;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private final SortedMap<V, SortedMap<V, Edge>> adjList = new TreeMap<>();
|
||||||
|
private final SortedSet<V> nodes = new TreeSet<>();
|
||||||
|
private final boolean isResidualGraph;
|
||||||
|
private V sourceNode, sinkNode;
|
||||||
|
|
||||||
|
public Graph() {
|
||||||
|
this(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Graph(final boolean isResidualGraph) {
|
||||||
|
this.isResidualGraph = isResidualGraph;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void addEdge(final V u, final V v, final int capacity, final int cost, final int flow) {
|
||||||
|
addEdge(u, new Edge(v, capacity, cost, capacity - flow, flow));
|
||||||
|
}
|
||||||
|
|
||||||
|
public Set<V> nodes() {
|
||||||
|
return nodes;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Map<V, Edge> edges(final V node) {
|
||||||
|
return adjList.get(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isResidualGraph() {
|
||||||
|
return isResidualGraph;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setSourceNode(final V node) {
|
||||||
|
sourceNode = node;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setSinkNode(final V node) {
|
||||||
|
sinkNode = node;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int totalCost() {
|
||||||
|
int totalCost = 0;
|
||||||
|
for (final Map.Entry<V, SortedMap<V, Edge>> nodeEdges : adjList.entrySet()) {
|
||||||
|
final SortedMap<V, Edge> edges = nodeEdges.getValue();
|
||||||
|
for (final Entry<V, Edge> nodeEdge : edges.entrySet()) {
|
||||||
|
totalCost += nodeEdge.getValue().cost * nodeEdge.getValue().flow;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return totalCost;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addEdge(final V u, final Edge edge) {
|
||||||
|
if (!isResidualGraph) {
|
||||||
|
// Check if there's already an edge from u to v
|
||||||
|
final Map<V, Edge> edgeMap = adjList.get(edge.destination);
|
||||||
|
if (edgeMap != null && edgeMap.containsKey(u)) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"There is already an edge from " + edge.destination
|
||||||
|
+ " to " + u + ". Can not add an edge from " + u + " to " + edge.destination
|
||||||
|
+ " since there will create a cycle between two nodes");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
adjList.computeIfAbsent(u, set -> new TreeMap<>()).put(edge.destination, edge);
|
||||||
|
nodes.add(u);
|
||||||
|
nodes.add(edge.destination);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get residual graph of this graph.
|
||||||
|
* Residual graph definition:
|
||||||
|
* If there is an edge in original graph from u to v with capacity c, cost w and flow f,
|
||||||
|
* then in the new graph there are two edges e1 and e2. e1 is from u to v with capacity c - f,
|
||||||
|
* cost w and flow f. e2 is from v to u with capacity f, cost -w and flow 0.
|
||||||
|
*
|
||||||
|
* @return Residual graph
|
||||||
|
*/
|
||||||
|
public Graph<V> residualGraph() {
|
||||||
|
if (isResidualGraph) {
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
final Graph<V> residualGraph = new Graph<>(true);
|
||||||
|
for (final Map.Entry<V, SortedMap<V, Edge>> nodeEdges : adjList.entrySet()) {
|
||||||
|
final V node = nodeEdges.getKey();
|
||||||
|
final SortedMap<V, Edge> edges = nodeEdges.getValue();
|
||||||
|
for (final Entry<V, Edge> nodeEdge : edges.entrySet()) {
|
||||||
|
final Edge edge = nodeEdge.getValue();
|
||||||
|
final Edge forwardEdge = new Edge(edge.destination, edge.capacity, edge.cost, edge.capacity - edge.flow, edge.flow);
|
||||||
|
final Edge backwardEdge = new Edge(node, edge.capacity, edge.cost * -1, edge.flow, 0, false);
|
||||||
|
forwardEdge.counterEdge = backwardEdge;
|
||||||
|
backwardEdge.counterEdge = forwardEdge;
|
||||||
|
residualGraph.addEdge(node, forwardEdge);
|
||||||
|
residualGraph.addEdge(edge.destination, backwardEdge);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return residualGraph;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Solve min cost flow with cycle canceling algorithm.
|
||||||
|
*/
|
||||||
|
public void solveMinCostFlow() {
|
||||||
|
validateMinCostGraph();
|
||||||
|
final Graph<V> residualGraph = residualGraph();
|
||||||
|
residualGraph.cancelNegativeCycles();
|
||||||
|
|
||||||
|
for (final Entry<V, SortedMap<V, Edge>> nodeEdges : adjList.entrySet()) {
|
||||||
|
final V node = nodeEdges.getKey();
|
||||||
|
for (final Entry<V, Edge> nodeEdge : nodeEdges.getValue().entrySet()) {
|
||||||
|
final V destination = nodeEdge.getKey();
|
||||||
|
final Edge edge = nodeEdge.getValue();
|
||||||
|
final Edge residualEdge = residualGraph.adjList.get(node).get(destination);
|
||||||
|
edge.flow = residualEdge.flow;
|
||||||
|
edge.residualFlow = residualEdge.residualFlow;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void populateInOutFlow(final Map<V, Long> inFlow, final Map<V, Long> outFlow) {
|
||||||
|
for (final Entry<V, SortedMap<V, Edge>> nodeEdges : adjList.entrySet()) {
|
||||||
|
final V node = nodeEdges.getKey();
|
||||||
|
if (node.equals(sinkNode)) {
|
||||||
|
throw new IllegalStateException("Sink node " + sinkNode + " shouldn't have output");
|
||||||
|
}
|
||||||
|
for (final Entry<V, Edge> nodeEdge : nodeEdges.getValue().entrySet()) {
|
||||||
|
final V destination = nodeEdge.getKey();
|
||||||
|
if (destination.equals(sourceNode)) {
|
||||||
|
throw new IllegalStateException("Source node " + sourceNode + " shouldn't have input " + node);
|
||||||
|
}
|
||||||
|
final Edge edge = nodeEdge.getValue();
|
||||||
|
Long count = outFlow.get(node);
|
||||||
|
if (count == null) {
|
||||||
|
outFlow.put(node, (long) edge.flow);
|
||||||
|
} else {
|
||||||
|
outFlow.put(node, count + edge.flow);
|
||||||
|
}
|
||||||
|
|
||||||
|
count = inFlow.get(destination);
|
||||||
|
if (count == null) {
|
||||||
|
inFlow.put(destination, (long) edge.flow);
|
||||||
|
} else {
|
||||||
|
inFlow.put(destination, count + edge.flow);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void validateMinCostGraph() {
|
||||||
|
if (isResidualGraph) {
|
||||||
|
throw new IllegalStateException("Should not be residual graph to solve min cost flow");
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
Check provided flow satisfying below constraints:
|
||||||
|
1. Input flow and output flow for each node should be the same except for source and destination node
|
||||||
|
2. Output flow of source and input flow of destination should be the same
|
||||||
|
*/
|
||||||
|
|
||||||
|
final Map<V, Long> inFlow = new HashMap<>();
|
||||||
|
final Map<V, Long> outFlow = new HashMap<>();
|
||||||
|
populateInOutFlow(inFlow, outFlow);
|
||||||
|
|
||||||
|
for (final Entry<V, Long> in : inFlow.entrySet()) {
|
||||||
|
if (in.getKey().equals(sourceNode) || in.getKey().equals(sinkNode)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
final Long out = outFlow.get(in.getKey());
|
||||||
|
if (!Objects.equals(in.getValue(), out)) {
|
||||||
|
throw new IllegalStateException("Input flow for node " + in.getKey() + " is " +
|
||||||
|
in.getValue() + " which doesn't match output flow " + out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
final Long sourceOutput = outFlow.get(sourceNode);
|
||||||
|
final Long sinkInput = inFlow.get(sinkNode);
|
||||||
|
if (!Objects.equals(sourceOutput, sinkInput)) {
|
||||||
|
throw new IllegalStateException("Output flow for source " + sourceNode + " is " + sourceOutput
|
||||||
|
+ " which doesn't match input flow " + sinkInput + " for sink " + sinkNode);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void cancelNegativeCycles() {
|
||||||
|
if (!isResidualGraph) {
|
||||||
|
throw new IllegalStateException("Should be residual graph to cancel negative cycles");
|
||||||
|
}
|
||||||
|
boolean cyclePossible = true;
|
||||||
|
while (cyclePossible) {
|
||||||
|
cyclePossible = false;
|
||||||
|
for (final V node : nodes) {
|
||||||
|
final Map<V, V> parentNodes = new HashMap<>();
|
||||||
|
final Map<V, Edge> parentEdges = new HashMap<>();
|
||||||
|
final V possibleNodeInCycle = detectNegativeCycles(node, parentNodes, parentEdges);
|
||||||
|
|
||||||
|
if (possibleNodeInCycle == null) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
final Set<V> visited = new HashSet<>();
|
||||||
|
V nodeInCycle = possibleNodeInCycle;
|
||||||
|
while (!visited.contains(nodeInCycle)) {
|
||||||
|
visited.add(nodeInCycle);
|
||||||
|
nodeInCycle = parentNodes.get(nodeInCycle);
|
||||||
|
}
|
||||||
|
|
||||||
|
cyclePossible = true;
|
||||||
|
cancelNegativeCycle(nodeInCycle, parentNodes, parentEdges);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void cancelNegativeCycle(final V nodeInCycle, final Map<V, V> parentNodes, final Map<V, Edge> parentEdges) {
|
||||||
|
// Start from parentNode since nodeInCyle is used as exit condition in below loops
|
||||||
|
final V parentNode = parentNodes.get(nodeInCycle);
|
||||||
|
Edge parentEdge = parentEdges.get(nodeInCycle);
|
||||||
|
|
||||||
|
// Find max possible negative flow
|
||||||
|
int possibleFlow = parentEdge.residualFlow;
|
||||||
|
for (V curNode = parentNode; curNode != nodeInCycle; curNode = parentNodes.get(curNode)) {
|
||||||
|
parentEdge = parentEdges.get(curNode);
|
||||||
|
possibleFlow = Math.min(possibleFlow, parentEdge.residualFlow);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update graph by removing negative flow
|
||||||
|
parentEdge = parentEdges.get(nodeInCycle);
|
||||||
|
Edge counterEdge = parentEdge.counterEdge;
|
||||||
|
parentEdge.residualFlow -= possibleFlow;
|
||||||
|
if (parentEdge.forwardEdge) {
|
||||||
|
parentEdge.flow += possibleFlow;
|
||||||
|
}
|
||||||
|
counterEdge.residualFlow += possibleFlow;
|
||||||
|
if (counterEdge.forwardEdge && counterEdge.flow >= possibleFlow) {
|
||||||
|
counterEdge.flow -= possibleFlow;
|
||||||
|
}
|
||||||
|
for (V curNode = parentNode; curNode != nodeInCycle; curNode = parentNodes.get(curNode)) {
|
||||||
|
parentEdge = parentEdges.get(curNode);
|
||||||
|
counterEdge = parentEdge.counterEdge;
|
||||||
|
parentEdge.residualFlow -= possibleFlow;
|
||||||
|
if (parentEdge.forwardEdge) {
|
||||||
|
parentEdge.flow += possibleFlow;
|
||||||
|
}
|
||||||
|
counterEdge.residualFlow += possibleFlow;
|
||||||
|
if (counterEdge.forwardEdge && counterEdge.flow >= possibleFlow) {
|
||||||
|
counterEdge.flow -= possibleFlow;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Detect negative cycle using Bellman-ford shortest path algorithm.
|
||||||
|
* @param source Source node
|
||||||
|
* @param parentNodes Parent nodes to store negative cycle nodes
|
||||||
|
* @param parentEdges Parent edges to store negative cycle edges
|
||||||
|
*
|
||||||
|
* @return One node which can lead to negative cycle if exists or null if there's no negative cycle
|
||||||
|
*/
|
||||||
|
V detectNegativeCycles(final V source, final Map<V, V> parentNodes, final Map<V, Edge> parentEdges) {
|
||||||
|
// Use long to account for any overflow
|
||||||
|
final Map<V, Long> distance = nodes.stream().collect(Collectors.toMap(node -> node, node -> (long) Integer.MAX_VALUE));
|
||||||
|
distance.put(source, 0L);
|
||||||
|
final int nodeCount = nodes.size();
|
||||||
|
|
||||||
|
// Iterate nodeCount iterations since Bellaman-Ford will find shortest path in nodeCount - 1
|
||||||
|
// iterations. If the distance can still be relaxed in nodeCount iteration, there's a negative
|
||||||
|
// cycle
|
||||||
|
for (int i = 0; i < nodeCount; i++) {
|
||||||
|
// Iterate through all edges
|
||||||
|
for (final Entry<V, SortedMap<V, Edge>> nodeEdges : adjList.entrySet()) {
|
||||||
|
final V u = nodeEdges.getKey();
|
||||||
|
for (final Entry<V, Edge> nodeEdge : nodeEdges.getValue().entrySet()) {
|
||||||
|
final Edge edge = nodeEdge.getValue();
|
||||||
|
if (edge.residualFlow == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
final V v = edge.destination;
|
||||||
|
if (distance.get(v) > distance.get(u) + edge.cost) {
|
||||||
|
if (i == nodeCount - 1) {
|
||||||
|
return v;
|
||||||
|
}
|
||||||
|
distance.put(v, distance.get(u) + edge.cost);
|
||||||
|
parentNodes.put(v, u);
|
||||||
|
parentEdges.put(v, edge);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
|
@ -14,7 +14,7 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
package org.apache.kafka.streams.processor.internals;
|
package org.apache.kafka.streams.processor.internals.assignment;
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
|
@ -30,6 +30,7 @@ import org.apache.kafka.common.TopicPartition;
|
||||||
import org.apache.kafka.common.TopicPartitionInfo;
|
import org.apache.kafka.common.TopicPartitionInfo;
|
||||||
import org.apache.kafka.streams.KeyValue;
|
import org.apache.kafka.streams.KeyValue;
|
||||||
import org.apache.kafka.streams.processor.TaskId;
|
import org.apache.kafka.streams.processor.TaskId;
|
||||||
|
import org.apache.kafka.streams.processor.internals.InternalTopicManager;
|
||||||
import org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
|
import org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
|
||||||
import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs;
|
import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
|
@ -0,0 +1,414 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.kafka.streams.processor.internals.assignment;
|
||||||
|
|
||||||
|
import static org.hamcrest.MatcherAssert.assertThat;
|
||||||
|
import static org.hamcrest.Matchers.contains;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertSame;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
import org.junit.Before;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
public class GraphTest {
|
||||||
|
private Graph<Integer> graph;
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void setUp() {
|
||||||
|
/*
|
||||||
|
* Node 0 and 2 are both connected to node 1 and 3. There's a flow of 1 unit from 0 to 1 and 2 to
|
||||||
|
* 3. The total cost in this case is 5. Min cost should be 2 by flowing 1 unit from 0 to 3 and 2
|
||||||
|
* to 1
|
||||||
|
*/
|
||||||
|
graph = new Graph<>();
|
||||||
|
graph.addEdge(0, 1, 1, 3, 1);
|
||||||
|
graph.addEdge(0, 3, 1, 1, 0);
|
||||||
|
graph.addEdge(2, 1, 1, 1, 0);
|
||||||
|
graph.addEdge(2, 3, 1, 2, 1);
|
||||||
|
graph.addEdge(4, 0, 1, 0, 1);
|
||||||
|
graph.addEdge(4, 2, 1, 0, 1);
|
||||||
|
graph.addEdge(1, 5, 1, 0, 1);
|
||||||
|
graph.addEdge(3, 5, 1, 0, 1);
|
||||||
|
graph.setSourceNode(4);
|
||||||
|
graph.setSinkNode(5);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testBasic() {
|
||||||
|
final Set<Integer> nodes = graph.nodes();
|
||||||
|
assertEquals(6, nodes.size());
|
||||||
|
assertThat(nodes, contains(0, 1, 2, 3, 4, 5));
|
||||||
|
|
||||||
|
Map<Integer, Graph<Integer>.Edge> edges = graph.edges(0);
|
||||||
|
assertEquals(2, edges.size());
|
||||||
|
assertEquals(getEdge(1, 1, 3, 0, 1), edges.get(1));
|
||||||
|
assertEquals(getEdge(3, 1, 1, 1, 0), edges.get(3));
|
||||||
|
|
||||||
|
edges = graph.edges(2);
|
||||||
|
assertEquals(2, edges.size());
|
||||||
|
assertEquals(getEdge(1, 1, 1, 1, 0), edges.get(1));
|
||||||
|
assertEquals(getEdge(3, 1, 2, 0, 1), edges.get(3));
|
||||||
|
|
||||||
|
edges = graph.edges(1);
|
||||||
|
assertEquals(1, edges.size());
|
||||||
|
assertEquals(getEdge(5, 1, 0, 0, 1), edges.get(5));
|
||||||
|
|
||||||
|
edges = graph.edges(3);
|
||||||
|
assertEquals(1, edges.size());
|
||||||
|
assertEquals(getEdge(5, 1, 0, 0, 1), edges.get(5));
|
||||||
|
|
||||||
|
edges = graph.edges(4);
|
||||||
|
assertEquals(2, edges.size());
|
||||||
|
assertEquals(getEdge(0, 1, 0, 0, 1), edges.get(0));
|
||||||
|
assertEquals(getEdge(2, 1, 0, 0, 1), edges.get(2));
|
||||||
|
|
||||||
|
edges = graph.edges(5);
|
||||||
|
assertNull(edges);
|
||||||
|
|
||||||
|
assertFalse(graph.isResidualGraph());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testResidualGraph() {
|
||||||
|
final Graph<Integer> residualGraph = graph.residualGraph();
|
||||||
|
final Graph<Integer> residualGraph1 = residualGraph.residualGraph();
|
||||||
|
assertSame(residualGraph1, residualGraph);
|
||||||
|
|
||||||
|
final Set<Integer> nodes = residualGraph.nodes();
|
||||||
|
assertEquals(6, nodes.size());
|
||||||
|
assertThat(nodes, contains(0, 1, 2, 3, 4, 5));
|
||||||
|
|
||||||
|
Map<Integer, Graph<Integer>.Edge> edges = residualGraph.edges(0);
|
||||||
|
assertEquals(3, edges.size());
|
||||||
|
assertEquals(getEdge(1, 1, 3, 0, 1), edges.get(1));
|
||||||
|
assertEquals(getEdge(3, 1, 1, 1, 0), edges.get(3));
|
||||||
|
assertEquals(getEdge(4, 1, 0, 1, 0, false), edges.get(4));
|
||||||
|
|
||||||
|
edges = residualGraph.edges(2);
|
||||||
|
assertEquals(3, edges.size());
|
||||||
|
assertEquals(getEdge(1, 1, 1, 1, 0), edges.get(1));
|
||||||
|
assertEquals(getEdge(3, 1, 2, 0, 1), edges.get(3));
|
||||||
|
assertEquals(getEdge(4, 1, 0, 1, 0, false), edges.get(4));
|
||||||
|
|
||||||
|
edges = residualGraph.edges(1);
|
||||||
|
assertEquals(3, edges.size());
|
||||||
|
assertEquals(getEdge(0, 1, -3, 1, 0, false), edges.get(0));
|
||||||
|
assertEquals(getEdge(2, 1, -1, 0, 0, false), edges.get(2));
|
||||||
|
assertEquals(getEdge(5, 1, 0, 0, 1), edges.get(5));
|
||||||
|
|
||||||
|
edges = residualGraph.edges(3);
|
||||||
|
assertEquals(3, edges.size());
|
||||||
|
assertEquals(getEdge(0, 1, -1, 0, 0, false), edges.get(0));
|
||||||
|
assertEquals(getEdge(2, 1, -2, 1, 0, false), edges.get(2));
|
||||||
|
assertEquals(getEdge(5, 1, 0, 0, 1), edges.get(5));
|
||||||
|
|
||||||
|
assertTrue(residualGraph.isResidualGraph());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testInvalidOperation() {
|
||||||
|
final Graph<Integer> graph1 = new Graph<>();
|
||||||
|
Exception exception = assertThrows(IllegalArgumentException.class, () -> graph1.addEdge(0, 1, -1, 0, 0));
|
||||||
|
assertEquals("Edge capacity cannot be negative", exception.getMessage());
|
||||||
|
|
||||||
|
exception = assertThrows(IllegalArgumentException.class, () -> graph1.addEdge(0, 1, 1, 0, 2));
|
||||||
|
assertEquals("Edge flow 2 cannot exceed capacity 1", exception.getMessage());
|
||||||
|
|
||||||
|
graph1.addEdge(0, 1, 1, 1, 1);
|
||||||
|
exception = assertThrows(IllegalArgumentException.class, () -> graph1.addEdge(1, 0, 1, 0, 0));
|
||||||
|
assertEquals("There is already an edge from 0 to 1. Can not add an edge from 1 to 0 since "
|
||||||
|
+ "there will create a cycle between two nodes", exception.getMessage());
|
||||||
|
|
||||||
|
final Graph<Integer> residualGraph = graph1.residualGraph();
|
||||||
|
exception = assertThrows(IllegalStateException.class, residualGraph::solveMinCostFlow);
|
||||||
|
assertEquals("Should not be residual graph to solve min cost flow", exception.getMessage());
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testInvalidSource() {
|
||||||
|
final Graph<Integer> graph1 = new Graph<>();
|
||||||
|
graph1.addEdge(0, 1, 1, 1, 0);
|
||||||
|
graph1.addEdge(1, 2, 1, 1, 0);
|
||||||
|
graph1.setSourceNode(1);
|
||||||
|
graph1.setSinkNode(2);
|
||||||
|
final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow);
|
||||||
|
assertEquals("Source node 1 shouldn't have input 0", exception.getMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testInvalidSink() {
|
||||||
|
final Graph<Integer> graph1 = new Graph<>();
|
||||||
|
graph1.addEdge(0, 1, 1, 1, 0);
|
||||||
|
graph1.addEdge(1, 2, 1, 1, 0);
|
||||||
|
graph1.setSourceNode(0);
|
||||||
|
graph1.setSinkNode(1);
|
||||||
|
final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow);
|
||||||
|
assertEquals("Sink node 1 shouldn't have output", exception.getMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testInvalidFlow() {
|
||||||
|
final Graph<Integer> graph1 = new Graph<>();
|
||||||
|
graph1.addEdge(0, 1, 1, 1, 1);
|
||||||
|
graph1.addEdge(0, 2, 2, 1, 2);
|
||||||
|
graph1.addEdge(1, 3, 1, 1, 1);
|
||||||
|
graph1.addEdge(2, 3, 2, 1, 0); // Missing flow from 2 to 3
|
||||||
|
graph1.setSourceNode(0);
|
||||||
|
graph1.setSinkNode(3);
|
||||||
|
final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow);
|
||||||
|
assertEquals("Input flow for node 2 is 2 which doesn't match output flow 0", exception.getMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testMissingSource() {
|
||||||
|
final Graph<Integer> graph1 = new Graph<>();
|
||||||
|
graph1.addEdge(0, 1, 1, 1, 1);
|
||||||
|
graph1.addEdge(0, 2, 2, 1, 2);
|
||||||
|
graph1.addEdge(1, 3, 1, 1, 1);
|
||||||
|
graph1.addEdge(2, 3, 2, 1, 2);
|
||||||
|
graph1.setSinkNode(3);
|
||||||
|
final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow);
|
||||||
|
assertEquals("Output flow for source null is null which doesn't match input flow 3 for sink 3",
|
||||||
|
exception.getMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testDisconnectedGraph() {
|
||||||
|
final Graph<Integer> graph1 = new Graph<>();
|
||||||
|
graph1.addEdge(0, 1, 1, 1, 1);
|
||||||
|
graph1.addEdge(2, 3, 2, 1, 2);
|
||||||
|
graph1.setSourceNode(0);
|
||||||
|
graph1.setSinkNode(1);
|
||||||
|
final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow);
|
||||||
|
assertEquals("Input flow for node 3 is 2 which doesn't match output flow null",
|
||||||
|
exception.getMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testDisconnectedGraphCrossSourceSink() {
|
||||||
|
final Graph<Integer> graph1 = new Graph<>();
|
||||||
|
graph1.addEdge(0, 1, 1, 1, 1);
|
||||||
|
graph1.addEdge(2, 3, 2, 1, 2);
|
||||||
|
graph1.setSourceNode(0);
|
||||||
|
graph1.setSinkNode(3);
|
||||||
|
final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow);
|
||||||
|
assertEquals("Input flow for node 1 is 1 which doesn't match output flow null",
|
||||||
|
exception.getMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testJustSourceSink() {
|
||||||
|
final Graph<Integer> graph1 = new Graph<>();
|
||||||
|
graph1.addEdge(0, 1, 1, 1, 1);
|
||||||
|
graph1.setSourceNode(0);
|
||||||
|
graph1.setSinkNode(1);
|
||||||
|
graph1.solveMinCostFlow();
|
||||||
|
assertEquals(1, graph1.totalCost());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testMinCostFlow() {
|
||||||
|
// Original graph, flow from 0 to 1 and 2 to 3
|
||||||
|
Map<Integer, Graph<Integer>.Edge> edges = graph.edges(0);
|
||||||
|
Graph<Integer>.Edge edge = edges.get(1);
|
||||||
|
assertEquals(1, edge.flow);
|
||||||
|
assertEquals(0, edge.residualFlow);
|
||||||
|
|
||||||
|
edge = edges.get(3);
|
||||||
|
assertEquals(0, edge.flow);
|
||||||
|
assertEquals(1, edge.residualFlow);
|
||||||
|
|
||||||
|
edges = graph.edges(2);
|
||||||
|
edge = edges.get(3);
|
||||||
|
assertEquals(1, edge.flow);
|
||||||
|
assertEquals(0, edge.residualFlow);
|
||||||
|
|
||||||
|
edge = edges.get(1);
|
||||||
|
assertEquals(0, edge.flow);
|
||||||
|
assertEquals(1, edge.residualFlow);
|
||||||
|
|
||||||
|
assertEquals(5, graph.totalCost());
|
||||||
|
|
||||||
|
graph.solveMinCostFlow();
|
||||||
|
|
||||||
|
assertEquals(2, graph.totalCost());
|
||||||
|
|
||||||
|
edges = graph.edges(0);
|
||||||
|
assertEquals(2, edges.size());
|
||||||
|
|
||||||
|
// No flow from 0 to 1
|
||||||
|
edge = edges.get(1);
|
||||||
|
assertEquals(0, edge.flow);
|
||||||
|
assertEquals(1, edge.residualFlow);
|
||||||
|
|
||||||
|
// Flow from 0 to 3 now
|
||||||
|
edge = edges.get(3);
|
||||||
|
assertEquals(1, edge.flow);
|
||||||
|
assertEquals(0, edge.residualFlow);
|
||||||
|
|
||||||
|
edges = graph.edges(2);
|
||||||
|
assertEquals(2, edges.size());
|
||||||
|
|
||||||
|
// No flow from 2 to 3
|
||||||
|
edge = edges.get(3);
|
||||||
|
assertEquals(0, edge.flow);
|
||||||
|
assertEquals(1, edge.residualFlow);
|
||||||
|
|
||||||
|
// Flow from 2 to 1 now
|
||||||
|
edge = edges.get(1);
|
||||||
|
assertEquals(1, edge.flow);
|
||||||
|
assertEquals(0, edge.residualFlow);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testMinCostDetectNodeNotInNegativeCycle() {
|
||||||
|
final Graph<Integer> graph1 = new Graph<>();
|
||||||
|
|
||||||
|
graph1.addEdge(5, 0, 1, 0, 1);
|
||||||
|
graph1.addEdge(5, 1, 1, 0, 1);
|
||||||
|
|
||||||
|
graph1.addEdge(0, 2, 1, 1, 0);
|
||||||
|
graph1.addEdge(0, 3, 1, 1, 0);
|
||||||
|
graph1.addEdge(0, 4, 1, 10, 1);
|
||||||
|
|
||||||
|
graph1.addEdge(1, 2, 1, 1, 0);
|
||||||
|
graph1.addEdge(1, 3, 1, 10, 1);
|
||||||
|
graph1.addEdge(1, 4, 1, 1, 0);
|
||||||
|
|
||||||
|
graph1.addEdge(2, 6, 0, 0, 0);
|
||||||
|
graph1.addEdge(3, 6, 1, 0, 1);
|
||||||
|
graph1.addEdge(4, 6, 1, 0, 1);
|
||||||
|
|
||||||
|
graph1.setSourceNode(5);
|
||||||
|
graph1.setSinkNode(6);
|
||||||
|
|
||||||
|
assertEquals(20, graph1.totalCost());
|
||||||
|
|
||||||
|
// In this graph, the node we found for negative cycle is 2. However 2 isn't in the negative
|
||||||
|
// cycle itself. Negative cycle is 1 -> 4 -> 0 -> 3 -> 1
|
||||||
|
graph1.solveMinCostFlow();
|
||||||
|
assertEquals(2, graph1.totalCost());
|
||||||
|
|
||||||
|
Map<Integer, Graph<Integer>.Edge> edges = graph1.edges(5);
|
||||||
|
assertEquals(getEdge(0, 1, 0, 0, 1), edges.get(0));
|
||||||
|
assertEquals(getEdge(1, 1, 0, 0, 1), edges.get(1));
|
||||||
|
|
||||||
|
edges = graph1.edges(0);
|
||||||
|
assertEquals(getEdge(2, 1, 1, 1, 0), edges.get(2));
|
||||||
|
assertEquals(getEdge(3, 1, 1, 0, 1), edges.get(3));
|
||||||
|
assertEquals(getEdge(4, 1, 10, 1, 0), edges.get(4));
|
||||||
|
|
||||||
|
edges = graph1.edges(1);
|
||||||
|
assertEquals(getEdge(2, 1, 1, 1, 0), edges.get(2));
|
||||||
|
assertEquals(getEdge(3, 1, 10, 1, 0), edges.get(3));
|
||||||
|
assertEquals(getEdge(4, 1, 1, 0, 1), edges.get(4));
|
||||||
|
|
||||||
|
edges = graph1.edges(2);
|
||||||
|
assertEquals(getEdge(6, 0, 0, 0, 0), edges.get(6));
|
||||||
|
|
||||||
|
edges = graph1.edges(3);
|
||||||
|
assertEquals(getEdge(6, 1, 0, 0, 1), edges.get(6));
|
||||||
|
|
||||||
|
edges = graph1.edges(4);
|
||||||
|
assertEquals(getEdge(6, 1, 0, 0, 1), edges.get(6));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testDeterministic() {
|
||||||
|
final List<TestEdge> edgeList = new ArrayList<>();
|
||||||
|
edgeList.add(new TestEdge(0, 1, 1, 2, 1));
|
||||||
|
edgeList.add(new TestEdge(0, 2, 1, 1, 0));
|
||||||
|
edgeList.add(new TestEdge(0, 3, 1, 1, 0));
|
||||||
|
edgeList.add(new TestEdge(0, 4, 1, 1, 0));
|
||||||
|
edgeList.add(new TestEdge(1, 5, 1, 1, 1));
|
||||||
|
edgeList.add(new TestEdge(2, 5, 1, 1, 0));
|
||||||
|
edgeList.add(new TestEdge(3, 5, 1, 1, 0));
|
||||||
|
edgeList.add(new TestEdge(4, 5, 1, 1, 0));
|
||||||
|
|
||||||
|
// Test no matter the order of adding edges, min cost flow flows from 0 to 2 and then from 2 to 5
|
||||||
|
for (int i = 0; i < 10; i++) {
|
||||||
|
Collections.shuffle(edgeList);
|
||||||
|
final Graph<Integer> graph1 = new Graph<>();
|
||||||
|
for (final TestEdge edge : edgeList) {
|
||||||
|
graph1.addEdge(edge.source, edge.destination, edge.capacity, edge.cost, edge.flow);
|
||||||
|
}
|
||||||
|
graph1.setSourceNode(0);
|
||||||
|
graph1.setSinkNode(5);
|
||||||
|
assertEquals(3, graph1.totalCost());
|
||||||
|
|
||||||
|
graph1.solveMinCostFlow();
|
||||||
|
assertEquals(2, graph1.totalCost());
|
||||||
|
|
||||||
|
Map<Integer, Graph<Integer>.Edge> edges = graph1.edges(0);
|
||||||
|
assertEquals(4, edges.size());
|
||||||
|
assertEquals(getEdge(1, 1, 2, 1, 0), edges.get(1));
|
||||||
|
assertEquals(getEdge(2, 1, 1, 0, 1), edges.get(2));
|
||||||
|
assertEquals(getEdge(3, 1, 1, 1, 0), edges.get(3));
|
||||||
|
assertEquals(getEdge(4, 1, 1, 1, 0), edges.get(4));
|
||||||
|
|
||||||
|
edges = graph1.edges(1);
|
||||||
|
assertEquals(1, edges.size());
|
||||||
|
assertEquals(getEdge(5, 1, 1, 1, 0), edges.get(5));
|
||||||
|
|
||||||
|
edges = graph1.edges(2);
|
||||||
|
assertEquals(1, edges.size());
|
||||||
|
assertEquals(getEdge(5, 1, 1, 0, 1), edges.get(5));
|
||||||
|
|
||||||
|
edges = graph1.edges(3);
|
||||||
|
assertEquals(1, edges.size());
|
||||||
|
assertEquals(getEdge(5, 1, 1, 1, 0), edges.get(5));
|
||||||
|
|
||||||
|
edges = graph1.edges(4);
|
||||||
|
assertEquals(1, edges.size());
|
||||||
|
assertEquals(getEdge(5, 1, 1, 1, 0), edges.get(5));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Graph<Integer>.Edge getEdge(final int destination, final int capacity, final int cost, final int residualFlow, final int flow) {
|
||||||
|
return getEdge(destination, capacity, cost, residualFlow, flow, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Graph<Integer>.Edge getEdge(final int destination, final int capacity, final int cost, final int residualFlow, final int flow, final boolean forwardEdge) {
|
||||||
|
return new Graph<Integer>().new Edge(destination, capacity, cost, residualFlow, flow, forwardEdge);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class TestEdge {
|
||||||
|
final int source;
|
||||||
|
final int destination;
|
||||||
|
final int capacity;
|
||||||
|
final int cost;
|
||||||
|
final int flow;
|
||||||
|
|
||||||
|
TestEdge(final int source, final int destination, final int capacity, final int cost, final int flow) {
|
||||||
|
this.source = source;
|
||||||
|
this.destination = destination;
|
||||||
|
this.capacity = capacity;
|
||||||
|
this.cost = cost;
|
||||||
|
this.flow = flow;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -40,7 +40,6 @@ import org.apache.kafka.common.utils.MockTime;
|
||||||
import org.apache.kafka.streams.StreamsConfig;
|
import org.apache.kafka.streams.StreamsConfig;
|
||||||
import org.apache.kafka.streams.StreamsConfig.InternalConfig;
|
import org.apache.kafka.streams.StreamsConfig.InternalConfig;
|
||||||
import org.apache.kafka.streams.processor.TaskId;
|
import org.apache.kafka.streams.processor.TaskId;
|
||||||
import org.apache.kafka.streams.processor.internals.RackAwareTaskAssignor;
|
|
||||||
import org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
|
import org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
|
||||||
import org.apache.kafka.test.MockClientSupplier;
|
import org.apache.kafka.test.MockClientSupplier;
|
||||||
import org.apache.kafka.test.MockInternalTopicManager;
|
import org.apache.kafka.test.MockInternalTopicManager;
|
||||||
|
|
Loading…
Reference in New Issue