You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hugegraph.apache.org by ji...@apache.org on 2022/11/09 10:25:31 UTC
[incubator-hugegraph] 28/33: add StressCentrality v2 (#65)
This is an automated email from the ASF dual-hosted git repository.
jin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph.git
commit 476477bb7b0f8c2e85e3c875b9d45be68266ec49
Author: houzhizhen <ho...@163.com>
AuthorDate: Wed Oct 14 16:58:55 2020 +0800
add StressCentrality v2 (#65)
* add StressCentralityAlgorithmV2
* add BfsTraverser and ClosenessCentralityAlgorithmV2
---
.../hugegraph/job/algorithm/AlgorithmPool.java | 7 +
.../hugegraph/job/algorithm/BfsTraverser.java | 150 ++++++++++++++++
.../cent/BetweennessCentralityAlgorithmV2.java | 192 +++++++--------------
.../cent/ClosenessCentralityAlgorithmV2.java | 135 +++++++++++++++
.../cent/StressCentralityAlgorithmV2.java | 182 +++++++++++++++++++
5 files changed, 539 insertions(+), 127 deletions(-)
diff --git a/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/AlgorithmPool.java b/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/AlgorithmPool.java
index 9ded512ef..02ac4c24e 100644
--- a/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/AlgorithmPool.java
+++ b/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/AlgorithmPool.java
@@ -23,10 +23,13 @@ import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import com.baidu.hugegraph.job.algorithm.cent.BetweennessCentralityAlgorithm;
+import com.baidu.hugegraph.job.algorithm.cent.BetweennessCentralityAlgorithmV2;
import com.baidu.hugegraph.job.algorithm.cent.ClosenessCentralityAlgorithm;
+import com.baidu.hugegraph.job.algorithm.cent.ClosenessCentralityAlgorithmV2;
import com.baidu.hugegraph.job.algorithm.cent.DegreeCentralityAlgorithm;
import com.baidu.hugegraph.job.algorithm.cent.EigenvectorCentralityAlgorithm;
import com.baidu.hugegraph.job.algorithm.cent.StressCentralityAlgorithm;
+import com.baidu.hugegraph.job.algorithm.cent.StressCentralityAlgorithmV2;
import com.baidu.hugegraph.job.algorithm.comm.ClusterCoeffcientAlgorithm;
import com.baidu.hugegraph.job.algorithm.comm.KCoreAlgorithm;
import com.baidu.hugegraph.job.algorithm.comm.LouvainAlgorithm;
@@ -65,6 +68,10 @@ public class AlgorithmPool {
INSTANCE.register(new PageRankAlgorithm());
INSTANCE.register(new SubgraphStatAlgorithm());
+
+ INSTANCE.register(new StressCentralityAlgorithmV2());
+ INSTANCE.register(new BetweennessCentralityAlgorithmV2());
+ INSTANCE.register(new ClosenessCentralityAlgorithmV2());
}
private final Map<String, Algorithm> algorithms;
diff --git a/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/BfsTraverser.java b/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/BfsTraverser.java
new file mode 100644
index 000000000..034887f27
--- /dev/null
+++ b/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/BfsTraverser.java
@@ -0,0 +1,150 @@
+/*
+ * Copyright 2017 HugeGraph Authors
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with this
+ * work for additional information regarding copyright ownership. The ASF
+ * licenses this file to You under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+
+package com.baidu.hugegraph.job.algorithm;
+
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.Map;
+import java.util.Stack;
+
+import org.apache.tinkerpop.gremlin.structure.Edge;
+
+import com.baidu.hugegraph.backend.id.Id;
+import com.baidu.hugegraph.job.UserJob;
+import com.baidu.hugegraph.structure.HugeEdge;
+import com.baidu.hugegraph.type.define.Directions;
+
+public abstract class BfsTraverser<T extends BfsTraverser.Node>
+ extends AbstractAlgorithm.AlgoTraverser
+ implements AutoCloseable {
+
+ private Stack<Id> traversedVertices = new Stack<>();
+
+ public BfsTraverser(UserJob<Object> job) {
+ super(job);
+ }
+
+ protected void compute(Id startVertex, Directions direction,
+ Id edgeLabel, long degree, long depth) {
+ Map<Id, T> localNodes = this.traverse(startVertex, direction,
+ edgeLabel, degree, depth);
+ this.backtrack(startVertex, localNodes);
+ }
+
+ protected Map<Id, T> traverse(Id startVertex, Directions direction,
+ Id edgeLabel, long degree, long depth) {
+ Map<Id, T> localNodes = new HashMap<>();
+ localNodes.put(startVertex, this.createStartNode());
+
+ LinkedList<Id> traversingVertices = new LinkedList<>();
+ traversingVertices.add(startVertex);
+ while (!traversingVertices.isEmpty()) {
+ Id source = traversingVertices.removeFirst();
+ this.traversedVertices.push(source);
+ T sourceNode = localNodes.get(source);
+ if (depth != NO_LIMIT && sourceNode.distance() >= depth) {
+ continue;
+ }
+ // TODO: sample the edges
+ Iterator<Edge> edges = this.edgesOfVertex(source, direction,
+ edgeLabel, degree);
+ while (edges.hasNext()) {
+ HugeEdge edge = (HugeEdge) edges.next();
+ Id target = edge.otherVertex().id();
+ T targetNode = localNodes.get(target);
+ boolean firstTime = false;
+ // Edge's targetNode is arrived at first time
+ if (targetNode == null) {
+ firstTime = true;
+ targetNode = this.createNode(sourceNode);
+ localNodes.put(target, targetNode);
+ traversingVertices.addLast(target);
+ }
+ this.meetNode(target, targetNode, source,
+ sourceNode, firstTime);
+ }
+ }
+ return localNodes;
+ }
+
+ protected void backtrack(Id startVertex, Map<Id, T> localNodes) {
+ while (!this.traversedVertices.empty()) {
+ Id currentVertex = this.traversedVertices.pop();
+ this.backtrack(startVertex, currentVertex, localNodes);
+ }
+ }
+
+ protected abstract T createStartNode();
+
+ protected abstract T createNode(T parentNode);
+
+ protected abstract void meetNode(Id currentVertex, T currentNode,
+ Id parentVertex, T parentNode,
+ boolean firstTime);
+
+ protected abstract void backtrack(Id startVertex, Id currentVertex,
+ Map<Id, T> localNodes);
+
+ public static class Node {
+
+ private Id[] parents;
+ private int pathCount;
+ private int distance;
+
+ public Node(Node parentNode) {
+ this(0, parentNode.distance + 1);
+ }
+
+ public Node(int pathCount, int distance) {
+ this.pathCount = pathCount;
+ this.distance = distance;
+ this.parents = new Id[0];
+ }
+
+ public int distance() {
+ return this.distance;
+ }
+
+ public Id[] parents() {
+ return this.parents;
+ }
+
+ public void addParent(Id parentId) {
+ // TODO: test if need to allocate more memory in advance
+ Id[] newParents = new Id[this.parents.length + 1];
+ System.arraycopy(this.parents, 0, newParents, 0,
+ this.parents.length);
+ newParents[newParents.length - 1] = parentId;
+ this.parents = newParents;
+ }
+
+ public void addParentNodeIfNeeded(Node node, Id parentId) {
+ if (this.distance == node.distance + 1) {
+ this.pathCount += node.pathCount;
+ this.addParent(parentId);
+ }
+ }
+
+ protected int pathCount() {
+ return this.pathCount;
+ }
+ }
+}
diff --git a/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/cent/BetweennessCentralityAlgorithmV2.java b/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/cent/BetweennessCentralityAlgorithmV2.java
index f5f7e4db6..1391021a2 100644
--- a/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/cent/BetweennessCentralityAlgorithmV2.java
+++ b/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/cent/BetweennessCentralityAlgorithmV2.java
@@ -21,21 +21,19 @@ package com.baidu.hugegraph.job.algorithm.cent;
import java.util.HashMap;
import java.util.Iterator;
-import java.util.LinkedList;
import java.util.Map;
-import java.util.Stack;
+import org.apache.commons.lang3.mutable.MutableFloat;
import org.apache.tinkerpop.gremlin.structure.Vertex;
import com.baidu.hugegraph.backend.id.Id;
import com.baidu.hugegraph.backend.query.Query;
import com.baidu.hugegraph.job.UserJob;
-import com.baidu.hugegraph.structure.HugeEdge;
+import com.baidu.hugegraph.job.algorithm.BfsTraverser;
import com.baidu.hugegraph.structure.HugeVertex;
import com.baidu.hugegraph.traversal.algorithm.HugeTraverser;
import com.baidu.hugegraph.type.define.Directions;
-
public class BetweennessCentralityAlgorithmV2 extends AbstractCentAlgorithm {
@Override
@@ -63,29 +61,31 @@ public class BetweennessCentralityAlgorithmV2 extends AbstractCentAlgorithm {
}
}
- private static class Traverser extends AbstractCentAlgorithm.Traverser {
+ private static class Traverser extends BfsTraverser<BetweennessNode> {
+
+ private Map<Id, MutableFloat> globalBetweennesses;
private Traverser(UserJob<Object> job) {
super(job);
}
private Object betweenessCentrality(Directions direction,
- String label,
- int depth,
- long degree,
- long sample,
- String sourceLabel,
- long sourceSample,
- String sourceCLabel,
- long topN) {
+ String label,
+ int depth,
+ long degree,
+ long sample,
+ String sourceLabel,
+ long sourceSample,
+ String sourceCLabel,
+ long topN) {
assert depth > 0;
assert degree > 0L || degree == NO_LIMIT;
assert topN >= 0L || topN == NO_LIMIT;
- Map<Id, Float> globalBetweennesses = new HashMap<>();
+ this.globalBetweennesses = new HashMap<>();
Id edgeLabelId = null;
if (label != null) {
- edgeLabelId = graph().edgeLabel(label).id();
+ edgeLabelId = this.graph().edgeLabel(label).id();
}
// TODO: sample the startVertices
@@ -93,144 +93,82 @@ public class BetweennessCentralityAlgorithmV2 extends AbstractCentAlgorithm {
sourceCLabel,
Query.NO_LIMIT);
while (startVertices.hasNext()) {
- Id startVertex = ((HugeVertex) startVertices.next()).id();
- globalBetweennesses.putIfAbsent(startVertex, 0.0f);
- Stack<Id> traversedVertices = new Stack<>();
- Map<Id, BetweennessNode> localBetweennesses = new HashMap<>();
- BetweennessNode startNode = new BetweennessNode(1, 0);
- localBetweennesses.put(startVertex, startNode);
- this.computeDistance(startVertex, localBetweennesses,
- traversedVertices, direction,
- edgeLabelId, depth, degree);
- this.computeBetweenness(startVertex, traversedVertices,
- globalBetweennesses,
- localBetweennesses);
+ Id startVertex = ((HugeVertex) startVertices.next()).id();
+ this.globalBetweennesses.putIfAbsent(startVertex,
+ new MutableFloat());
+ this.compute(startVertex, direction, edgeLabelId,
+ degree, depth);
}
- if (topN > 0) {
- return HugeTraverser.topN(globalBetweennesses, true, topN);
+ if (topN > 0L || topN == NO_LIMIT) {
+ return HugeTraverser.topN(this.globalBetweennesses,
+ true, topN);
} else {
- return globalBetweennesses;
+ return this.globalBetweennesses;
}
}
- private void computeDistance(Id startVertex,
- Map<Id, BetweennessNode> betweennesses,
- Stack<Id> traversedVertices, Directions direction,
- Id edgeLabelId, long degree, long depth) {
- LinkedList<Id> traversingVertices = new LinkedList<>();
- traversingVertices.add(startVertex);
-
- while (!traversingVertices.isEmpty()) {
- Id source = traversingVertices.removeFirst();
- traversedVertices.push(source);
- BetweennessNode sourceNode = betweennesses.get(source);
- if (sourceNode == null) {
- sourceNode = new BetweennessNode();
- betweennesses.put(source, sourceNode);
- }
- // TODO: sample the edges
- Iterator<HugeEdge> edges = (Iterator) this.edgesOfVertex(
- source, direction, edgeLabelId,
- degree);
- while (edges.hasNext()) {
- HugeEdge edge = edges.next();
- Id targetId = edge.otherVertex().id();
- BetweennessNode targetNode = betweennesses.get(targetId);
- // edge's targetNode is arrived at first time
- if (targetNode == null) {
- targetNode = new BetweennessNode(sourceNode);
- betweennesses.put(targetId, targetNode);
- if (depth == NO_LIMIT ||
- targetNode.distance() <= depth) {
- traversingVertices.addLast(targetId);
- }
- }
- targetNode.addParentNodeIfNeeded(sourceNode, source);
- }
- }
+ @Override
+ protected BetweennessNode createNode(BetweennessNode parentNode) {
+ return new BetweennessNode(parentNode);
}
- private void computeBetweenness(
- Id startVertex,
- Stack<Id> traversedVertices,
- Map<Id, Float> globalBetweennesses,
- Map<Id, BetweennessNode> localBetweennesses) {
- while (!traversedVertices.empty()) {
- Id currentId = traversedVertices.pop();
- BetweennessNode currentNode =
- localBetweennesses.get(currentId);
- if (currentId.equals(startVertex)) {
- continue;
- }
- // add to globalBetweennesses
- float betweenness = globalBetweennesses.getOrDefault(currentId,
- 0.0f);
- betweenness += currentNode.betweenness();
- globalBetweennesses.put(currentId, betweenness);
-
- // contribute to parent
- for (Id v : currentNode.parents()) {
- BetweennessNode parentNode = localBetweennesses.get(v);
- parentNode.increaseBetweenness(currentNode);
- }
+ @Override
+ protected void meetNode(Id currentVertex, BetweennessNode currentNode,
+ Id parentVertex, BetweennessNode parentNode,
+ boolean firstTime) {
+ currentNode.addParentNodeIfNeeded(parentNode, parentVertex);
+ }
+
+ @Override
+ protected BetweennessNode createStartNode() {
+ return new BetweennessNode(1, 0);
+ }
+
+ @Override
+ protected void backtrack(Id startVertex, Id currentVertex,
+ Map<Id, BetweennessNode> localNodes) {
+ if (startVertex.equals(currentVertex)) {
+ return;
+ }
+ MutableFloat betweenness = this.globalBetweennesses.get(
+ currentVertex);
+ if (betweenness == null) {
+ betweenness = new MutableFloat(0.0F);
+ this.globalBetweennesses.put(currentVertex, betweenness);
+ }
+ BetweennessNode node = localNodes.get(currentVertex);
+ betweenness.add(node.betweenness());
+
+ // Contribute to parents
+ for (Id v : node.parents()) {
+ BetweennessNode parentNode = localNodes.get(v);
+ parentNode.increaseBetweenness(node);
}
}
}
/**
- * the temp data structure for a vertex used in computing process.
+ * Temp data structure for a vertex used in computing process.
*/
- private static class BetweennessNode {
+ private static class BetweennessNode extends BfsTraverser.Node {
- private Id[] parents;
- private int pathCount;
- private int distance;
private float betweenness;
- public BetweennessNode() {
- this(0, -1);
- }
-
public BetweennessNode(BetweennessNode parentNode) {
- this(0, parentNode.distance + 1);
+ this(0, parentNode.distance() + 1);
}
public BetweennessNode(int pathCount, int distance) {
- this.pathCount = pathCount;
- this.distance = distance;
- this.parents = new Id[0];
- this.betweenness = 0.0f;
- }
-
- public int distance() {
- return this.distance;
- }
-
- public Id[] parents() {
- return this.parents;
- }
-
- public void addParent(Id parentId) {
- Id[] newParents = new Id[this.parents.length + 1];
- System.arraycopy(this.parents, 0, newParents, 0,
- this.parents.length);
- newParents[newParents.length - 1] = parentId;
- this.parents = newParents;
+ super(pathCount, distance);
+ this.betweenness = 0.0F;
}
public void increaseBetweenness(BetweennessNode childNode) {
- float increase = (float) this.pathCount / childNode.pathCount *
- (1 + childNode.betweenness);
+ float increase = (float) this.pathCount() / childNode.pathCount() *
+ (1.0F + childNode.betweenness);
this.betweenness += increase;
}
- public void addParentNodeIfNeeded(BetweennessNode node, Id parentId) {
- if (this.distance == node.distance + 1) {
- this.pathCount += node.pathCount;
- this.addParent(parentId);
- }
- }
-
public float betweenness() {
return this.betweenness;
}
diff --git a/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/cent/ClosenessCentralityAlgorithmV2.java b/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/cent/ClosenessCentralityAlgorithmV2.java
new file mode 100644
index 000000000..1651c8943
--- /dev/null
+++ b/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/cent/ClosenessCentralityAlgorithmV2.java
@@ -0,0 +1,135 @@
+/*
+ * Copyright 2017 HugeGraph Authors
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with this
+ * work for additional information regarding copyright ownership. The ASF
+ * licenses this file to You under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+package com.baidu.hugegraph.job.algorithm.cent;
+
+import com.baidu.hugegraph.backend.id.Id;
+import com.baidu.hugegraph.backend.query.Query;
+import com.baidu.hugegraph.exception.NotSupportException;
+import com.baidu.hugegraph.job.UserJob;
+import com.baidu.hugegraph.job.algorithm.BfsTraverser;
+import com.baidu.hugegraph.structure.HugeVertex;
+import com.baidu.hugegraph.traversal.algorithm.HugeTraverser;
+import com.baidu.hugegraph.type.define.Directions;
+import org.apache.tinkerpop.gremlin.structure.Vertex;
+
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Map;
+
+public class ClosenessCentralityAlgorithmV2 extends AbstractCentAlgorithm {
+
+ @Override
+ public String name() {
+ return "closeness_centrality";
+ }
+
+ @Override
+ public void checkParameters(Map<String, Object> parameters) {
+ super.checkParameters(parameters);
+ }
+
+ @Override
+ public Object call(UserJob<Object> job, Map<String, Object> parameters) {
+ try (Traverser traverser = new Traverser(job)) {
+ return traverser.closenessCentrality(direction(parameters),
+ edgeLabel(parameters),
+ depth(parameters),
+ degree(parameters),
+ sample(parameters),
+ sourceLabel(parameters),
+ sourceSample(parameters),
+ sourceCLabel(parameters),
+ top(parameters));
+ }
+ }
+
+ private static class Traverser extends BfsTraverser<BfsTraverser.Node> {
+
+ private Map<Id, Float> globalCloseness;
+
+ private float startVertexCloseness;
+
+ private Traverser(UserJob<Object> job) {
+ super(job);
+ this.globalCloseness = new HashMap<>();
+ }
+
+ private Object closenessCentrality(Directions direction,
+ String label,
+ int depth,
+ long degree,
+ long sample,
+ String sourceLabel,
+ long sourceSample,
+ String sourceCLabel,
+ long topN) {
+ assert depth > 0;
+ assert degree > 0L || degree == NO_LIMIT;
+ assert topN >= 0L || topN == NO_LIMIT;
+
+ Id edgeLabelId = null;
+ if (label != null) {
+ edgeLabelId = this.graph().edgeLabel(label).id();
+ }
+
+ // TODO: sample the startVertices
+ Iterator<Vertex> startVertices = this.vertices(sourceLabel,
+ sourceCLabel,
+ Query.NO_LIMIT);
+ while (startVertices.hasNext()) {
+ this.startVertexCloseness = 0.0F;
+ Id startVertex = ((HugeVertex) startVertices.next()).id();
+ this.traverse(startVertex, direction, edgeLabelId,
+ degree, depth);
+ this.globalCloseness.put(startVertex,
+ this.startVertexCloseness);
+ }
+ if (topN > 0L || topN == NO_LIMIT) {
+ return HugeTraverser.topN(this.globalCloseness, true, topN);
+ } else {
+ return this.globalCloseness;
+ }
+ }
+
+ @Override
+ protected Node createStartNode() {
+ return new Node(1, 0);
+ }
+
+ @Override
+ protected Node createNode(Node parentNode) {
+ return new Node(parentNode);
+ }
+
+ @Override
+ protected void meetNode(Id currentVertex, Node currentNode,
+ Id parentVertex, Node parentNode,
+ boolean firstTime) {
+ if (firstTime) {
+ this.startVertexCloseness += 1.0F / currentNode.distance();
+ }
+ }
+
+ @Override
+ protected void backtrack(Id startVertex, Id currentVertex,
+ Map<Id, Node> localNodes) {
+ throw new NotSupportException("backtrack()");
+ }
+ }
+}
diff --git a/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/cent/StressCentralityAlgorithmV2.java b/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/cent/StressCentralityAlgorithmV2.java
new file mode 100644
index 000000000..b01486b7e
--- /dev/null
+++ b/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/cent/StressCentralityAlgorithmV2.java
@@ -0,0 +1,182 @@
+/*
+ * Copyright 2017 HugeGraph Authors
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with this
+ * work for additional information regarding copyright ownership. The ASF
+ * licenses this file to You under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+
+package com.baidu.hugegraph.job.algorithm.cent;
+
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Map;
+
+import org.apache.commons.lang3.mutable.MutableLong;
+import org.apache.tinkerpop.gremlin.structure.Vertex;
+
+import com.baidu.hugegraph.backend.id.Id;
+import com.baidu.hugegraph.backend.query.Query;
+import com.baidu.hugegraph.job.UserJob;
+import com.baidu.hugegraph.job.algorithm.BfsTraverser;
+import com.baidu.hugegraph.structure.HugeVertex;
+import com.baidu.hugegraph.traversal.algorithm.HugeTraverser;
+import com.baidu.hugegraph.type.define.Directions;
+
+public class StressCentralityAlgorithmV2 extends AbstractCentAlgorithm {
+
+ @Override
+ public String name() {
+ return "stress_centrality";
+ }
+
+ @Override
+ public void checkParameters(Map<String, Object> parameters) {
+ super.checkParameters(parameters);
+ }
+
+ @Override
+ public Object call(UserJob<Object> job, Map<String, Object> parameters) {
+ try (Traverser traverser = new Traverser(job)) {
+ return traverser.stressCentrality(direction(parameters),
+ edgeLabel(parameters),
+ depth(parameters),
+ degree(parameters),
+ sample(parameters),
+ sourceLabel(parameters),
+ sourceSample(parameters),
+ sourceCLabel(parameters),
+ top(parameters));
+ }
+ }
+
+ private static class Traverser extends BfsTraverser<StressNode> {
+
+ private Map<Id, MutableLong> globalStresses;
+
+ private Traverser(UserJob<Object> job) {
+ super(job);
+ this.globalStresses = new HashMap<>();
+ }
+
+ private Object stressCentrality(Directions direction,
+ String label,
+ int depth,
+ long degree,
+ long sample,
+ String sourceLabel,
+ long sourceSample,
+ String sourceCLabel,
+ long topN) {
+ assert depth > 0;
+ assert degree > 0L || degree == NO_LIMIT;
+ assert topN >= 0L || topN == NO_LIMIT;
+
+ Id edgeLabelId = null;
+ if (label != null) {
+ edgeLabelId = this.graph().edgeLabel(label).id();
+ }
+
+ // TODO: sample the startVertices
+ Iterator<Vertex> startVertices = this.vertices(sourceLabel,
+ sourceCLabel,
+ Query.NO_LIMIT);
+ while (startVertices.hasNext()) {
+ Id startVertex = ((HugeVertex) startVertices.next()).id();
+ this.globalStresses.putIfAbsent(startVertex, new MutableLong(0L));
+ this.compute(startVertex, direction, edgeLabelId,
+ degree, depth);
+ }
+ if (topN > 0L || topN == NO_LIMIT) {
+ return HugeTraverser.topN(this.globalStresses, true, topN);
+ } else {
+ return this.globalStresses;
+ }
+ }
+
+ @Override
+ protected StressNode createStartNode() {
+ return new StressNode(1, 0);
+ }
+
+ @Override
+ protected StressNode createNode(StressNode parentNode) {
+ return new StressNode(parentNode);
+ }
+
+ @Override
+ protected void meetNode(Id currentVertex, StressNode currentNode,
+ Id parentVertex, StressNode parentNode,
+ boolean firstTime) {
+ currentNode.addParentNodeIfNeeded(parentNode, parentVertex);
+ }
+
+ @Override
+ protected void backtrack(Id startVertex, Id currentVertex,
+ Map<Id, StressNode> localNodes) {
+ if (startVertex.equals(currentVertex)) {
+ return;
+ }
+ StressNode currentNode = localNodes.get(currentVertex);
+
+ // Add local stresses to global stresses
+ MutableLong stress = this.globalStresses.get(currentVertex);
+ if (stress == null) {
+ stress = new MutableLong(0L);
+ this.globalStresses.put(currentVertex, stress);
+ }
+ stress.add(currentNode.stress());
+
+ // Contribute to parents
+ for (Id v : currentNode.parents()) {
+ StressNode parentNode = localNodes.get(v);
+ parentNode.increaseStress(currentNode);
+ }
+ }
+ }
+
+ /**
+ * Temp data structure for a vertex used in computing process.
+ */
+ private static class StressNode extends BfsTraverser.Node {
+
+ private long stress;
+
+ public StressNode(StressNode parentNode) {
+ this(0, parentNode.distance() + 1);
+ }
+
+ public StressNode(int pathCount, int distance) {
+ super(pathCount, distance);
+ this.stress = 0L;
+ }
+
+ public void increaseStress(StressNode childNode) {
+ /*
+ * `childNode.stress` is the contribution after child node.
+ * `childNode.pathCount` is the contribution of the child node.
+ * The sum of them is contribution to current node, there may be
+ * multi parents node of the child node, so contribute to current
+ * node proportionally.
+ */
+ long total = childNode.stress + childNode.pathCount();
+ long received = total * this.pathCount() / childNode.pathCount();
+ this.stress += received;
+ }
+
+ public long stress() {
+ return this.stress;
+ }
+ }
+}