You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hugegraph.apache.org by ji...@apache.org on 2022/11/09 10:25:10 UTC

[incubator-hugegraph] 07/33: louvain: add modularity parameter and fix isolated community lost (#14)

This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph.git

commit 9c9c26c3cb8aeb6090f08dcbcfc310309c034b31
Author: Jermy Li <li...@baidu.com>
AuthorDate: Wed May 6 15:15:53 2020 +0800

    louvain: add modularity parameter and fix isolated community lost (#14)
    
    * add modularity parameter for louvain
    * fix: louvain lost isolated communities when moving from one pass to the next
    
    Change-Id: I6a7dadc80635429aa2898939aa337aae01bc8d12
---
 .../hugegraph/job/algorithm/AbstractAlgorithm.java |   3 +-
 .../job/algorithm/comm/LouvainAlgorithm.java       |  20 ++-
 .../job/algorithm/comm/LouvainTraverser.java       | 187 ++++++++++++++-------
 3 files changed, 145 insertions(+), 65 deletions(-)

diff --git a/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/AbstractAlgorithm.java b/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/AbstractAlgorithm.java
index 248a92bdb..969bda1d8 100644
--- a/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/AbstractAlgorithm.java
+++ b/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/AbstractAlgorithm.java
@@ -59,7 +59,7 @@ import jersey.repackaged.com.google.common.base.Objects;
 public abstract class AbstractAlgorithm implements Algorithm {
 
     public static final long MAX_RESULT_SIZE = 100L * Bytes.MB;
-    public static final long MAX_QUERY_LIMIT = 10000000L; // about 10GB
+    public static final long MAX_QUERY_LIMIT = 100000000L; // about 100GB
     public static final int BATCH = 500;
 
     public static final String CATEGORY_AGGR = "aggregate";
@@ -81,6 +81,7 @@ public abstract class AbstractAlgorithm implements Algorithm {
     public static final String KEY_TIMES = "times";
     public static final String KEY_STABLE_TIMES = "stable_times";
     public static final String KEY_PRECISION = "precision";
+    public static final String KEY_SHOW_MOD= "show_modularity";
     public static final String KEY_SHOW_COMM = "show_community";
     public static final String KEY_CLEAR = "clear";
     public static final String KEY_CAPACITY = "capacity";
diff --git a/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/comm/LouvainAlgorithm.java b/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/comm/LouvainAlgorithm.java
index 3f6de63e8..c0c05f9a2 100644
--- a/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/comm/LouvainAlgorithm.java
+++ b/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/comm/LouvainAlgorithm.java
@@ -22,7 +22,7 @@ package com.baidu.hugegraph.job.algorithm.comm;
 import java.util.Map;
 
 import com.baidu.hugegraph.job.Job;
-import com.baidu.hugegraph.util.E;
+import com.baidu.hugegraph.traversal.algorithm.HugeTraverser;
 
 public class LouvainAlgorithm extends AbstractCommAlgorithm {
 
@@ -39,6 +39,7 @@ public class LouvainAlgorithm extends AbstractCommAlgorithm {
         degree(parameters);
         sourceLabel(parameters);
         sourceCLabel(parameters);
+        showModularity(parameters);
         showCommunity(parameters);
         clearPass(parameters);
     }
@@ -52,10 +53,13 @@ public class LouvainAlgorithm extends AbstractCommAlgorithm {
         LouvainTraverser traverser = new LouvainTraverser(job, degree,
                                                           label, clabel);
         Long clearPass = clearPass(parameters);
+        Long modPass = showModularity(parameters);
         String showComm = showCommunity(parameters);
         try {
             if (clearPass != null) {
                 return traverser.clearPass(clearPass.intValue());
+            } else if (modPass != null) {
+                return traverser.modularity(modPass.intValue());
             } else if (showComm != null) {
                 return traverser.showCommunity(showComm);
             } else {
@@ -74,10 +78,16 @@ public class LouvainAlgorithm extends AbstractCommAlgorithm {
             return null;
         }
         long pass = parameterLong(parameters, KEY_CLEAR);
-        // TODO: change to checkNonNegative()
-        E.checkArgument(pass >= 0 || pass == -1,
-                        "The %s parameter must be >= 0 or == -1, but got %s",
-                        KEY_CLEAR, pass);
+        HugeTraverser.checkNonNegativeOrNoLimit(pass, KEY_CLEAR);
+        return pass;
+    }
+
+    protected static Long showModularity(Map<String, Object> parameters) {
+        if (!parameters.containsKey(KEY_SHOW_MOD)) {
+            return null;
+        }
+        long pass = parameterLong(parameters, KEY_SHOW_MOD);
+        HugeTraverser.checkNonNegative(pass, KEY_SHOW_MOD);
         return pass;
     }
 }
diff --git a/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/comm/LouvainTraverser.java b/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/comm/LouvainTraverser.java
index 0177d8f2d..a63a1259d 100644
--- a/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/comm/LouvainTraverser.java
+++ b/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/comm/LouvainTraverser.java
@@ -28,6 +28,7 @@ import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
+import java.util.NoSuchElementException;
 import java.util.Set;
 
 import org.apache.commons.lang3.mutable.MutableInt;
@@ -52,6 +53,7 @@ import com.baidu.hugegraph.schema.VertexLabel;
 import com.baidu.hugegraph.structure.HugeEdge;
 import com.baidu.hugegraph.structure.HugeVertex;
 import com.baidu.hugegraph.type.define.Directions;
+import com.baidu.hugegraph.util.InsertionOrderUtil;
 import com.baidu.hugegraph.util.Log;
 import com.google.common.collect.ImmutableMap;
 
@@ -89,23 +91,6 @@ public class LouvainTraverser extends AlgoTraverser {
         this.cache = new Cache();
     }
 
-    @SuppressWarnings("unused")
-    private Id genId2(int pass, Id cid) {
-        // gen id for merge-community vertex
-        String id = cid.toString();
-        if (pass == 0) {
-            // conncat pass with cid
-            id = pass + "~" + id;
-        } else {
-            // replace last pass with current pass
-            String lastPass = String.valueOf(pass - 1);
-            assert id.startsWith(lastPass);
-            id = id.substring(lastPass.length());
-            id = pass + id;
-        }
-        return IdGenerator.of(id);
-    }
-
     private void defineSchemaOfPk() {
         String label = this.labelOfPassN(0);
         if (this.graph().existsVertexLabel(label) ||
@@ -131,8 +116,7 @@ public class LouvainTraverser extends AlgoTraverser {
         SchemaManager schema = this.graph().schema();
         try {
             schema.vertexLabel(this.passLabel).useCustomizeStringId()
-                  .properties(C_KIN, C_MEMBERS)
-                  .nullableKeys(C_KIN, C_MEMBERS)
+                  .properties(C_KIN, C_MEMBERS, C_WEIGHT)
                   .create();
             schema.edgeLabel(this.passLabel)
                   .sourceLabel(this.passLabel)
@@ -189,9 +173,16 @@ public class LouvainTraverser extends AlgoTraverser {
         return weight;
     }
 
-    private Vertex newCommunityNode(Id cid, int kin, List<String> members) {
+    private Vertex newCommunityNode(Id cid, float cweight,
+                                    int kin, List<String> members) {
         assert !members.isEmpty() : members;
-        return this.graph().addVertex(T.label, this.passLabel, T.id, cid,
+        /*
+         * cweight: members size(all pass) of the community, just for debug
+         * kin: edges weight in the community
+         * members: members id of the community of last pass
+         */
+        return this.graph().addVertex(T.label, this.passLabel,
+                                      T.id, cid, C_WEIGHT, cweight,
                                       C_KIN, kin, C_MEMBERS, members);
     }
 
@@ -204,12 +195,12 @@ public class LouvainTraverser extends AlgoTraverser {
         return source.addEdge(this.passLabel, target, C_WEIGHT, weight);
     }
 
-    private void insertNewCommunity(int pass, Id cid, int kin,
-                                    List<String> members,
+    private void insertNewCommunity(int pass, Id cid, float cweight,
+                                    int kin, List<String> members,
                                     Map<Id, MutableInt> cedges) {
         // create backend vertex if it's the first time
         Id vid = this.cache.genId(pass, cid);
-        Vertex node = this.newCommunityNode(vid, kin, members);
+        Vertex node = this.newCommunityNode(vid, cweight, kin, members);
         commitIfNeeded();
         // update backend vertex edges
         for (Map.Entry<Id, MutableInt> e : cedges.entrySet()) {
@@ -262,6 +253,7 @@ public class LouvainTraverser extends AlgoTraverser {
     }
 
     private float weightOfVertex(Vertex v, List<Edge> edges) {
+        // degree/weight of vertex
         Float value = this.cache.vertexWeight((Id) v.id());
         if (value != null) {
             return value;
@@ -281,9 +273,21 @@ public class LouvainTraverser extends AlgoTraverser {
         return 0;
     }
 
-    private Id cidOfVertex(Vertex v) {
+    private float cweightOfVertex(Vertex v) {
+        if (v.label().startsWith(C_PASS) && v.property(C_WEIGHT).isPresent()) {
+            return v.value(C_WEIGHT);
+        }
+        return 1f;
+    }
+
+    private Id cidOfVertex(Vertex v, List<Edge> nbs) {
         Id vid = (Id) v.id();
         Community c = this.cache.vertex2Community(vid);
+        // ensure source vertex exist in cache
+        if (c == null) {
+            c = this.wrapCommunity(v, nbs);
+            assert c != null;
+        }
         return c != null ? c.cid : vid;
     }
 
@@ -292,15 +296,15 @@ public class LouvainTraverser extends AlgoTraverser {
     //    and save as community vertex when merge()
     // 3: wrap community vertex as community node,
     //    and repeat step 2 and step 3.
-    private Community wrapCommunity(Vertex otherV) {
-        Id vid = (Id) otherV.id();
+    private Community wrapCommunity(Vertex v, List<Edge> nbs) {
+        Id vid = (Id) v.id();
         Community comm = this.cache.vertex2Community(vid);
         if (comm != null) {
             return comm;
         }
 
         comm = new Community(vid);
-        comm.add(this, otherV, null); // will traverse the neighbors of otherV
+        comm.add(this, v, nbs);
         this.cache.vertex2Community(vid, comm);
         return comm;
     }
@@ -316,7 +320,8 @@ public class LouvainTraverser extends AlgoTraverser {
                 // skip the old intermediate data, or filter clabel
                 continue;
             }
-            Community c = wrapCommunity(otherV);
+            // will traverse the neighbors of otherV
+            Community c = this.wrapCommunity(otherV, null);
             if (!comms.containsKey(c.cid)) {
                 comms.put(c.cid, Pair.of(c, new MutableInt(0)));
             }
@@ -359,8 +364,8 @@ public class LouvainTraverser extends AlgoTraverser {
                 continue;
             }
             total++;
-            Id cid = cidOfVertex(v);
             List<Edge> nbs = neighbors((Id) v.id());
+            Id cid = cidOfVertex(v, nbs);
             double ki = kinOfVertex(v) + weightOfVertex(v, nbs);
             // update community of v if △Q changed
             double maxDeltaQ = 0d;
@@ -377,13 +382,13 @@ public class LouvainTraverser extends AlgoTraverser {
                 // weight between c and otherC
                 double kiin = nbc.getRight().floatValue();
                 // weight of otherC
-                int tot = otherC.kin() + otherC.kout();
+                double tot = otherC.kin() + otherC.kout();
                 if (cid.equals(otherC.cid)) {
                     tot -= ki;
-                    assert tot >= 0;
+                    assert tot >= 0d;
                     // expect tot >= 0, but may be something wrong?
-                    if (tot < 0) {
-                        tot = 0;
+                    if (tot < 0d) {
+                        tot = 0d;
                     }
                 }
                 double deltaQ = kiin - ki * tot / this.m;
@@ -407,6 +412,7 @@ public class LouvainTraverser extends AlgoTraverser {
     private void mergeCommunities(int pass) {
         // merge each community as a vertex
         Collection<Pair<Community, Set<Id>>> comms = this.cache.communities();
+        assert this.allMembersExist(comms, pass -1);
         this.cache.resetVertexWeight();
         for (Pair<Community, Set<Id>> pair : comms) {
             Community c = pair.getKey();
@@ -417,6 +423,7 @@ public class LouvainTraverser extends AlgoTraverser {
             int kin = c.kin();
             Set<Id> vertices = pair.getRight();
             assert !vertices.isEmpty();
+            assert vertices.size() == c.size();
             List<String> members = new ArrayList<>(vertices.size());
             Map<Id, MutableInt> cedges = new HashMap<>(vertices.size());
             for (Id v : vertices) {
@@ -432,7 +439,8 @@ public class LouvainTraverser extends AlgoTraverser {
                         kin += weightOfEdge(edge);
                         continue;
                     }
-                    Id otherCid = cidOfVertex(otherV);
+                    assert this.cache.vertex2Community(otherV.id()) != null;
+                    Id otherCid = cidOfVertex(otherV, null);
                     if (otherCid.compareTo(c.cid) < 0) {
                         // skip if it should be collected by otherC
                         continue;
@@ -440,17 +448,33 @@ public class LouvainTraverser extends AlgoTraverser {
                     if (!cedges.containsKey(otherCid)) {
                         cedges.put(otherCid, new MutableInt(0));
                     }
+                    // update edge weight
                     cedges.get(otherCid).add(weightOfEdge(edge));
                 }
             }
             // insert new community vertex and edges into storage
-            this.insertNewCommunity(pass, c.cid, kin, members, cedges);
+            this.insertNewCommunity(pass, c.cid, c.weight(), kin, members, cedges);
         }
         this.graph().tx().commit();
         // reset communities
         this.cache.reset();
     }
 
+    private boolean allMembersExist(Collection<Pair<Community, Set<Id>>> comms,
+                                    int pass) {
+        String lastLabel = labelOfPassN(pass);
+        GraphTraversal<Vertex, Object> t = pass < 0 ? this.g.V().id() :
+                                           this.g.V().hasLabel(lastLabel).id();
+        Set<Object> all = this.execute(t, t::toSet);
+        for (Pair<Community, Set<Id>> comm : comms) {
+            all.removeAll(comm.getRight());
+        }
+        if (all.size() > 0) {
+            LOG.warn("Lost members of last pass: {}", all);
+        }
+        return all.isEmpty();
+    }
+
     public Object louvain(int maxTimes, int stableTimes, double precision) {
         assert maxTimes > 0;
         assert precision > 0d;
@@ -496,31 +520,40 @@ public class LouvainTraverser extends AlgoTraverser {
             }
         }
 
-        long communities = 0L;
+        Map<String, Object> results = InsertionOrderUtil.newMap();
+        results.putAll(ImmutableMap.of("pass_times", times,
+                                       "phase1_times", movedTimes,
+                                       "last_precision", movedPercent,
+                                       "times", maxTimes));
+        Number communities = 0L;
+        Number modularity = -1L;
         String commLabel = this.passLabel;
         if (!commLabel.isEmpty()) {
-            GraphTraversal<?, Long> t = this.g.V().hasLabel(commLabel).count();
-            communities = this.execute(t, t::next);
+            communities = tryNext(this.g.V().hasLabel(commLabel).count());
+            modularity = this.modularity(commLabel);
         }
-        return ImmutableMap.of("pass_times", times,
-                               "phase1_times", movedTimes,
-                               "last_precision", movedPercent,
-                               "times", maxTimes,
-                               "communities", communities);
+        results.putAll(ImmutableMap.of("communities", communities,
+                                       "modularity", modularity));
+        return results;
     }
 
     public double modularity(int pass) {
-        // pass: label the last pass
+        // community vertex label of one pass
         String label = labelOfPassN(pass);
-        Number kin = this.g.V().hasLabel(label).values(C_KIN).sum().next();
-        Number weight = this.g.E().hasLabel(label).values(C_WEIGHT).sum().next();
+        return this.modularity(label);
+    }
+
+    private double modularity(String label) {
+        // label: community vertex label of one pass
+        Number kin = tryNext(this.g.V().hasLabel(label).values(C_KIN).sum());
+        Number weight = tryNext(this.g.E().hasLabel(label).values(C_WEIGHT).sum());
         double m = kin.intValue() + weight.floatValue() * 2.0d;
         double q = 0.0d;
-        Iterator<Vertex> coms = this.g.V().hasLabel(label);
-        while (coms.hasNext()) {
-            Vertex com = coms.next();
-            int cin = com.value(C_KIN);
-            Number cout = this.g.V(com).bothE().values(C_WEIGHT).sum().next();
+        Iterator<Vertex> comms = this.vertices(label, LIMIT);
+        while (comms.hasNext()) {
+            Vertex comm = comms.next();
+            int cin = comm.value(C_KIN);
+            Number cout = tryNext(this.g.V(comm).bothE().values(C_WEIGHT).sum());
             double cdegree = cin + cout.floatValue();
             // Q = ∑(I/M - ((2I+O)/2M)^2)
             q += cin / m - Math.pow(cdegree / m, 2);
@@ -528,6 +561,16 @@ public class LouvainTraverser extends AlgoTraverser {
         return q;
     }
 
+    private <V extends Number> Number tryNext(GraphTraversal<?, V> iter) {
+        return this.execute(iter, () -> {
+            try {
+                return iter.next();
+            } catch (NoSuchElementException e) {
+                return 0;
+            }
+        });
+    }
+
     public Collection<Object> showCommunity(String community) {
         final String C_PASS0 = labelOfPassN(0);
         Collection<Object> comms = Arrays.asList(community);
@@ -604,8 +647,10 @@ public class LouvainTraverser extends AlgoTraverser {
 
         // community id (stored as a backend vertex)
         private final Id cid;
-        // community members size
+        // community members size of last pass [just for skip large community]
         private int size = 0;
+        // community members size of origin vertex [just for debug members lost]
+        private float weight = 0f;
         /*
          * weight of all edges in community(2X), sum of kin of new members
          *  [each is from the last pass, stored in backend vertex]
@@ -615,8 +660,7 @@ public class LouvainTraverser extends AlgoTraverser {
          * weight of all edges between communities, sum of kout of new members
          * [each is last pass, calculated in real time by neighbors]
          */
-        //
-        private int kout = 0;
+        private float kout = 0f;
 
         public Community(Id cid) {
             this.cid = cid;
@@ -630,14 +674,20 @@ public class LouvainTraverser extends AlgoTraverser {
             return this.size;
         }
 
+        public float weight() {
+            return this.weight;
+        }
+
         public void add(LouvainTraverser t, Vertex v, List<Edge> nbs) {
             this.size++;
+            this.weight += t.cweightOfVertex(v);
             this.kin += t.kinOfVertex(v);
             this.kout += t.weightOfVertex(v, nbs);
         }
 
         public void remove(LouvainTraverser t, Vertex v, List<Edge> nbs) {
             this.size--;
+            this.weight -= t.cweightOfVertex(v);
             this.kin -= t.kinOfVertex(v);
             this.kout -= t.weightOfVertex(v, nbs);
         }
@@ -646,14 +696,15 @@ public class LouvainTraverser extends AlgoTraverser {
             return this.kin;
         }
 
-        public int kout() {
+        public float kout() {
             return this.kout;
         }
 
         @Override
         public String toString() {
-            return String.format("[%s](size=%s kin=%s kout=%s)",
-                                 this.cid , this.size, this.kin, this.kout);
+            return String.format("[%s](size=%s weight=%s kin=%s kout=%s)",
+                                 this.cid , this.size, this.weight,
+                                 this.kin, this.kout);
         }
     }
 
@@ -669,7 +720,8 @@ public class LouvainTraverser extends AlgoTraverser {
             this.genIds = new HashMap<>();
         }
 
-        public Community vertex2Community(Id id) {
+        public Community vertex2Community(Object id) {
+            assert id instanceof Id;
             return this.vertex2Community.get(id);
         }
 
@@ -703,6 +755,23 @@ public class LouvainTraverser extends AlgoTraverser {
             return IdGenerator.of(id);
         }
 
+        @SuppressWarnings("unused")
+        public Id genId2(int pass, Id cid) {
+            // gen id for merge-community vertex
+            String id = cid.toString();
+            if (pass == 0) {
+                // conncat pass with cid
+                id = pass + "~" + id;
+            } else {
+                // replace last pass with current pass
+                String lastPass = String.valueOf(pass - 1);
+                assert id.startsWith(lastPass);
+                id = id.substring(lastPass.length());
+                id = pass + id;
+            }
+            return IdGenerator.of(id);
+        }
+
         public Collection<Pair<Community, Set<Id>>> communities(){
             // TODO: get communities from backend store instead of ram
             Map<Id, Pair<Community, Set<Id>>> comms = new HashMap<>();