Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/12/01 07:46:52 UTC

[GitHub] eric-haibin-lin closed pull request #13436: add graph_compact.

eric-haibin-lin closed pull request #13436: add graph_compact.
URL: https://github.com/apache/incubator-mxnet/pull/13436
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/docs/api/python/ndarray/contrib.md b/docs/api/python/ndarray/contrib.md
index 709ddae007c..d7c9021b595 100644
--- a/docs/api/python/ndarray/contrib.md
+++ b/docs/api/python/ndarray/contrib.md
@@ -61,6 +61,11 @@ In the rest of this document, we list routines provided by the `ndarray.contrib`
     index_copy
     getnnz
     edge_id
+    dgl_csr_neighbor_uniform_sample
+    dgl_csr_neighbor_non_uniform_sample
+    dgl_subgraph
+    dgl_adjacency
+    dgl_graph_compact
 ```
 
 ## API Reference
diff --git a/docs/api/python/symbol/contrib.md b/docs/api/python/symbol/contrib.md
index c0a4da54cbd..35cd11c89a7 100644
--- a/docs/api/python/symbol/contrib.md
+++ b/docs/api/python/symbol/contrib.md
@@ -55,6 +55,17 @@ In the rest of this document, we list routines provided by the `symbol.contrib`
     foreach
     while_loop
     cond
+    isinf
+    isfinite
+    isnan
+    index_copy
+    getnnz
+    edge_id
+    dgl_csr_neighbor_uniform_sample
+    dgl_csr_neighbor_non_uniform_sample
+    dgl_subgraph
+    dgl_adjacency
+    dgl_graph_compact
 ```
 
 ## API Reference
diff --git a/src/operator/contrib/dgl_graph.cc b/src/operator/contrib/dgl_graph.cc
index 74ad3d43564..ed7caacfdba 100644
--- a/src/operator/contrib/dgl_graph.cc
+++ b/src/operator/contrib/dgl_graph.cc
@@ -768,7 +768,10 @@ static void CSRNeighborUniformSampleComputeExCPU(const nnvm::NodeAttrs& attrs,
 NNVM_REGISTER_OP(_contrib_dgl_csr_neighbor_uniform_sample)
 .describe(R"code(This operator samples sub-graph from a csr graph via an
 uniform probability. 
-Example::
+
+Example:
+
+.. code:: python
 
   shape = (5, 5)
   data_np = np.array([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20], dtype=np.int64)
@@ -850,7 +853,10 @@ static void CSRNeighborNonUniformSampleComputeExCPU(const nnvm::NodeAttrs& attrs
 NNVM_REGISTER_OP(_contrib_dgl_csr_neighbor_non_uniform_sample)
 .describe(R"code(This operator samples sub-graph from a csr graph via an
 uniform probability. 
-Example::
+
+Example:
+
+.. code:: python
 
   shape = (5, 5)
   prob = mx.nd.array([0.9, 0.8, 0.2, 0.4, 0.1], dtype=np.float32)
@@ -1379,6 +1385,8 @@ the data value of float32.
 
 Example:
 
+.. code:: python
+
   x = [[ 1, 0, 0 ],
        [ 0, 2, 0 ],
        [ 0, 0, 3 ]]
@@ -1400,5 +1408,215 @@ the data value of float32.
 .set_attr<FComputeEx>("FComputeEx<cpu>", DGLAdjacencyForwardEx<cpu>)
 .add_argument("data", "NDArray-or-Symbol", "Input ndarray");
 
+///////////////////////// Compact subgraphs ///////////////////////////
+
+struct SubgraphCompactParam : public dmlc::Parameter<SubgraphCompactParam> {
+  int num_args;
+  bool return_mapping;
+  nnvm::Tuple<nnvm::dim_t> graph_sizes;
+  DMLC_DECLARE_PARAMETER(SubgraphCompactParam) {
+    DMLC_DECLARE_FIELD(num_args).set_lower_bound(2)
+    .describe("Number of input arguments.");
+    DMLC_DECLARE_FIELD(return_mapping)
+    .describe("Return mapping of vid and eid between the subgraph and the parent graph.");
+    DMLC_DECLARE_FIELD(graph_sizes)
+    .describe("the number of vertices in each graph.");
+  }
+};  // struct SubgraphCompactParam
+
+DMLC_REGISTER_PARAMETER(SubgraphCompactParam);
+
+static inline size_t get_num_graphs(const SubgraphCompactParam &params) {
+  // Each CSR needs a 1D array to store the original vertex Id for each row.
+  return params.num_args / 2;
+}
+
+static void CompactSubgraph(const NDArray &csr, const NDArray &vids,
+                            const NDArray &out_csr, size_t graph_size) {
+  TBlob in_idx_data = csr.aux_data(csr::kIdx);
+  TBlob in_ptr_data = csr.aux_data(csr::kIndPtr);
+  const dgl_id_t *indices_in = in_idx_data.dptr<dgl_id_t>();
+  const dgl_id_t *indptr_in = in_ptr_data.dptr<dgl_id_t>();
+  const dgl_id_t *row_ids = vids.data().dptr<dgl_id_t>();
+  size_t num_elems = csr.aux_data(csr::kIdx).shape_.Size();
+  // The last element in vids is the actual number of vertices in the subgraph.
+  CHECK_EQ(vids.shape()[0], in_ptr_data.shape_[0]);
+  CHECK_EQ(static_cast<size_t>(row_ids[vids.shape()[0] - 1]), graph_size);
+
+  // Prepare the Id map from the original graph to the subgraph.
+  std::unordered_map<dgl_id_t, dgl_id_t> id_map;
+  id_map.reserve(graph_size);
+  for (size_t i = 0; i < graph_size; i++) {
+    id_map.insert(std::pair<dgl_id_t, dgl_id_t>(row_ids[i], i));
+    CHECK_NE(row_ids[i], -1);
+  }
+
+  TShape nz_shape(1);
+  nz_shape[0] = num_elems;
+  TShape indptr_shape(1);
+  CHECK_EQ(out_csr.shape()[0], graph_size);
+  indptr_shape[0] = graph_size + 1;
+  CHECK_GE(in_ptr_data.shape_[0], indptr_shape[0]);
+
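+  // Allocate storage for the compacted CSR: edge data, column indices, and row pointers.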
+  out_csr.CheckAndAllocData(nz_shape);
+  out_csr.CheckAndAllocAuxData(csr::kIdx, nz_shape);
+  out_csr.CheckAndAllocAuxData(csr::kIndPtr, indptr_shape);
+
+  dgl_id_t *indices_out = out_csr.aux_data(csr::kIdx).dptr<dgl_id_t>();
+  dgl_id_t *indptr_out = out_csr.aux_data(csr::kIndPtr).dptr<dgl_id_t>();
+  dgl_id_t *sub_eids = out_csr.data().dptr<dgl_id_t>();
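+  // Row pointers for the retained rows are copied verbatim; only the trailing empty rows are dropped.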
+  std::copy(indptr_in, indptr_in + indptr_shape[0], indptr_out);
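+  // Remap each column index from the parent graph's vertex id to its compacted id
+  // and number the edges of the compacted subgraph sequentially.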
+  for (int64_t i = 0; i < nz_shape[0]; i++) {
+    dgl_id_t old_id = indices_in[i];
+    auto it = id_map.find(old_id);
+    CHECK(it != id_map.end());
+    indices_out[i] = it->second;
+    sub_eids[i] = i;
+  }
+}
+
+static void SubgraphCompactComputeExCPU(const nnvm::NodeAttrs& attrs,
+                                        const OpContext& ctx,
+                                        const std::vector<NDArray>& inputs,
+                                        const std::vector<OpReqType>& req,
+                                        const std::vector<NDArray>& outputs) {
+  const SubgraphCompactParam& params = nnvm::get<SubgraphCompactParam>(attrs.parsed);
+  int num_g = static_cast<int>(get_num_graphs(params));
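+  // Inputs hold num_g sampled CSR subgraphs followed by their num_g vertex id arrays;
+  // each (subgraph, vertex ids) pair is compacted independently.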
+#pragma omp parallel for
+  for (int i = 0; i < num_g; i++) {
+    CompactSubgraph(inputs[i], inputs[i + num_g], outputs[i], params.graph_sizes[i]);
+  }
+}
+
+static bool SubgraphCompactStorageType(const nnvm::NodeAttrs& attrs,
+                                       const int dev_mask,
+                                       DispatchMode* dispatch_mode,
+                                       std::vector<int> *in_attrs,
+                                       std::vector<int> *out_attrs) {
+  const SubgraphCompactParam& params = nnvm::get<SubgraphCompactParam>(attrs.parsed);
+  size_t num_g = get_num_graphs(params);
+  CHECK_EQ(num_g * 2, in_attrs->size());
+  // These are the input subgraphs.
+  for (size_t i = 0; i < num_g; i++)
+    CHECK_EQ(in_attrs->at(i), kCSRStorage);
+  // These are the vertex Ids in the original graph.
+  for (size_t i = 0; i < num_g; i++)
+    CHECK_EQ(in_attrs->at(i + num_g), kDefaultStorage);
+
+  bool success = true;
+  *dispatch_mode = DispatchMode::kFComputeEx;
+  for (size_t i = 0; i < out_attrs->size(); i++) {
+    if (!type_assign(&(*out_attrs)[i], mxnet::kCSRStorage))
+      success = false;
+  }
+  return success;
+}
+
+static bool SubgraphCompactShape(const nnvm::NodeAttrs& attrs,
+                                 std::vector<TShape> *in_attrs,
+                                 std::vector<TShape> *out_attrs) {
+  const SubgraphCompactParam& params = nnvm::get<SubgraphCompactParam>(attrs.parsed);
+  size_t num_g = get_num_graphs(params);
+  CHECK_EQ(num_g * 2, in_attrs->size());
+  // These are the input subgraphs.
+  for (size_t i = 0; i < num_g; i++) {
+    CHECK_EQ(in_attrs->at(i).ndim(), 2U);
+    CHECK_GE(in_attrs->at(i)[0], params.graph_sizes[i]);
+    CHECK_GE(in_attrs->at(i)[1], params.graph_sizes[i]);
+  }
+  // These are the vertex Ids in the original graph.
+  for (size_t i = 0; i < num_g; i++) {
+    CHECK_EQ(in_attrs->at(i + num_g).ndim(), 1U);
+    CHECK_GE(in_attrs->at(i + num_g)[0], params.graph_sizes[i]);
+  }
+
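+  // Each output adjacency matrix is square, sized by the compacted graph size.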
+  for (size_t i = 0; i < num_g; i++) {
+    TShape gshape(2);
+    gshape[0] = params.graph_sizes[i];
+    gshape[1] = params.graph_sizes[i];
+    out_attrs->at(i) = gshape;
+    if (params.return_mapping)
+      out_attrs->at(i + num_g) = gshape;
+  }
+  return true;
+}
+
+static bool SubgraphCompactType(const nnvm::NodeAttrs& attrs,
+                                std::vector<int> *in_attrs,
+                                std::vector<int> *out_attrs) {
+  for (size_t i = 0; i < in_attrs->size(); i++) {
+    CHECK_EQ(in_attrs->at(i), mshadow::kInt64);
+  }
+  for (size_t i = 0; i < out_attrs->size(); i++) {
+    out_attrs->at(i) = mshadow::kInt64;
+  }
+  return true;
+}
+
+NNVM_REGISTER_OP(_contrib_dgl_graph_compact)
+.describe(R"code(This operator compacts a CSR matrix generated by
+dgl_csr_neighbor_uniform_sample and dgl_csr_neighbor_non_uniform_sample.
+The CSR matrices generated by these two operators may have many empty
+rows at the end and many empty columns. This operator removes these
+empty rows and empty columns.
+
+Example:
+
+.. code:: python
+
+  shape = (5, 5)
+  data_np = np.array([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20], dtype=np.int64)
+  indices_np = np.array([1,2,3,4,0,2,3,4,0,1,3,4,0,1,2,4,0,1,2,3], dtype=np.int64)
+  indptr_np = np.array([0,4,8,12,16,20], dtype=np.int64)
+  a = mx.nd.sparse.csr_matrix((data_np, indices_np, indptr_np), shape=shape)
+  seed = mx.nd.array([0,1,2,3,4], dtype=np.int64)
+  out = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(a, seed, num_args=2, num_hops=1,
+          num_neighbor=2, max_num_vertices=6)
+  subg_v = out[0]
+  subg = out[1]
+  compact = mx.nd.contrib.dgl_graph_compact(subg, subg_v,
+          graph_sizes=(subg_v[-1].asnumpy()[0]), return_mapping=False)
+
+  compact.asnumpy()
+  array([[0, 0, 0, 1, 0],
+         [2, 0, 3, 0, 0],
+         [0, 4, 0, 0, 5],
+         [0, 6, 0, 0, 7],
+         [8, 9, 0, 0, 0]])
+
+)code" ADD_FILELINE)
+.set_attr_parser(ParamParser<SubgraphCompactParam>)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  const SubgraphCompactParam& params = nnvm::get<SubgraphCompactParam>(attrs.parsed);
+  return params.num_args;
+})
+.set_num_outputs([](const NodeAttrs& attrs) {
+  const SubgraphCompactParam& params = nnvm::get<SubgraphCompactParam>(attrs.parsed);
+  int num_varray = get_num_graphs(params);
+  if (params.return_mapping)
+    return num_varray * 2;
+  else
+    return num_varray;
+})
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+    [](const NodeAttrs& attrs) {
+  const SubgraphCompactParam& params = nnvm::get<SubgraphCompactParam>(attrs.parsed);
+  std::vector<std::string> names;
+  names.reserve(params.num_args);
+  size_t num_graphs = get_num_graphs(params);
+  for (size_t i = 0; i < num_graphs; i++)
+    names.push_back("graph" + std::to_string(i));
+  for (size_t i = 0; i < num_graphs; ++i)
+    names.push_back("varray" + std::to_string(i));
+  return names;
+})
+.set_attr<FInferStorageType>("FInferStorageType", SubgraphCompactStorageType)
+.set_attr<nnvm::FInferShape>("FInferShape", SubgraphCompactShape)
+.set_attr<nnvm::FInferType>("FInferType", SubgraphCompactType)
+.set_attr<FComputeEx>("FComputeEx<cpu>", SubgraphCompactComputeExCPU)
+.set_attr<std::string>("key_var_num_args", "num_args")
+.add_argument("graph_data", "NDArray-or-Symbol[]", "Input graphs and input vertex Ids.")
+.add_arguments(SubgraphCompactParam::__FIELDS__());
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_dgl_graph.py b/tests/python/unittest/test_dgl_graph.py
index f996d7f38de..069fef6e32f 100644
--- a/tests/python/unittest/test_dgl_graph.py
+++ b/tests/python/unittest/test_dgl_graph.py
@@ -63,6 +63,18 @@ def check_non_uniform(out, num_hops, max_num_vertices):
     for data in layer:
         assert(data <= num_hops)
 
+def check_compact(csr, id_arr, num_nodes):
+    compact = mx.nd.contrib.dgl_graph_compact(csr, id_arr, graph_sizes=num_nodes, return_mapping=False)
+    assert compact.shape[0] == num_nodes
+    assert compact.shape[1] == num_nodes
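+    # The compacted indptr should match the first num_nodes + 1 entries of the original indptr.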
+    assert mx.nd.sum(compact.indptr == csr.indptr[0:(num_nodes + 1)]).asnumpy() == num_nodes + 1
+    sub_indices = compact.indices.asnumpy()
+    indices = csr.indices.asnumpy()
+    id_arr = id_arr.asnumpy()
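+    # Each compacted column index should map back to the original vertex id through id_arr.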
+    for i in range(len(sub_indices)):
+        sub_id = sub_indices[i]
+        assert id_arr[sub_id] == indices[i]
+
 def test_uniform_sample():
     shape = (5, 5)
     data_np = np.array([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20], dtype=np.int64)
@@ -74,36 +86,64 @@ def test_uniform_sample():
     out = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(a, seed, num_args=2, num_hops=1, num_neighbor=2, max_num_vertices=5)
     assert (len(out) == 3)
     check_uniform(out, num_hops=1, max_num_vertices=5)
+    num_nodes = out[0][-1].asnumpy()
+    assert num_nodes > 0
+    assert num_nodes < len(out[0])
+    check_compact(out[1], out[0], num_nodes)
 
     seed = mx.nd.array([0], dtype=np.int64)
     out = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(a, seed, num_args=2, num_hops=1, num_neighbor=1, max_num_vertices=4)
     assert (len(out) == 3)
     check_uniform(out, num_hops=1, max_num_vertices=4)
+    num_nodes = out[0][-1].asnumpy()
+    assert num_nodes > 0
+    assert num_nodes < len(out[0])
+    check_compact(out[1], out[0], num_nodes)
 
     seed = mx.nd.array([0], dtype=np.int64)
     out = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(a, seed, num_args=2, num_hops=2, num_neighbor=1, max_num_vertices=4)
     assert (len(out) == 3)
     check_uniform(out, num_hops=2, max_num_vertices=4)
+    num_nodes = out[0][-1].asnumpy()
+    assert num_nodes > 0
+    assert num_nodes < len(out[0])
+    check_compact(out[1], out[0], num_nodes)
 
     seed = mx.nd.array([0,2,4], dtype=np.int64)
     out = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(a, seed, num_args=2, num_hops=1, num_neighbor=2, max_num_vertices=5)
     assert (len(out) == 3)
     check_uniform(out, num_hops=1, max_num_vertices=5)
+    num_nodes = out[0][-1].asnumpy()
+    assert num_nodes > 0
+    assert num_nodes < len(out[0])
+    check_compact(out[1], out[0], num_nodes)
 
     seed = mx.nd.array([0,4], dtype=np.int64)
     out = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(a, seed, num_args=2, num_hops=1, num_neighbor=2, max_num_vertices=5)
     assert (len(out) == 3)
     check_uniform(out, num_hops=1, max_num_vertices=5)
+    num_nodes = out[0][-1].asnumpy()
+    assert num_nodes > 0
+    assert num_nodes < len(out[0])
+    check_compact(out[1], out[0], num_nodes)
 
     seed = mx.nd.array([0,4], dtype=np.int64)
     out = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(a, seed, num_args=2, num_hops=2, num_neighbor=2, max_num_vertices=5)
     assert (len(out) == 3)
     check_uniform(out, num_hops=2, max_num_vertices=5)
+    num_nodes = out[0][-1].asnumpy()
+    assert num_nodes > 0
+    assert num_nodes < len(out[0])
+    check_compact(out[1], out[0], num_nodes)
 
     seed = mx.nd.array([0,4], dtype=np.int64)
     out = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(a, seed, num_args=2, num_hops=1, num_neighbor=2, max_num_vertices=5)
     assert (len(out) == 3)
     check_uniform(out, num_hops=1, max_num_vertices=5)
+    num_nodes = out[0][-1].asnumpy()
+    assert num_nodes > 0
+    assert num_nodes < len(out[0])
+    check_compact(out[1], out[0], num_nodes)
 
 def test_non_uniform_sample():
     shape = (5, 5)


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services