You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by re...@apache.org on 2018/06/21 18:40:39 UTC
[incubator-mxnet] branch subgraph updated: Graph partitioner and
subgraph op (#11251)
This is an automated email from the ASF dual-hosted git repository.
reminisce pushed a commit to branch subgraph
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/subgraph by this push:
new 9567bd7 Graph partitioner and subgraph op (#11251)
9567bd7 is described below
commit 9567bd7bcb70bf0ccfdbb24a7212d723200e18bd
Author: reminisce <wu...@gmail.com>
AuthorDate: Thu Jun 21 11:40:31 2018 -0700
Graph partitioner and subgraph op (#11251)
Graph partitioner and subgraph op
---
example/subgraph_op/common | 1 +
example/subgraph_op/imagenet_inference.py | 166 +++++++
include/mxnet/c_api.h | 5 +
include/mxnet/engine.h | 22 +-
include/mxnet/ndarray.h | 4 +
include/mxnet/op_attr_types.h | 7 +-
src/c_api/c_api_symbolic.cc | 25 +
src/engine/engine_impl.h | 10 +
src/engine/naive_engine.cc | 31 +-
src/engine/threaded_engine.cc | 10 +-
src/engine/threaded_engine.h | 1 +
src/executor/graph_executor.cc | 3 +
src/imperative/imperative_utils.h | 21 +-
src/ndarray/ndarray.cc | 2 +
src/operator/subgraph/common.h | 270 +++++++++++
src/operator/subgraph/default_subgraph_op.cc | 113 +++++
src/operator/subgraph/default_subgraph_op.cu | 41 ++
src/operator/subgraph/default_subgraph_op.h | 127 +++++
src/operator/subgraph/partition_graph.cc | 687 +++++++++++++++++++++++++++
tests/python/gpu/test_operator_gpu.py | 1 +
tests/python/unittest/test_gluon_trainer.py | 1 +
tests/python/unittest/test_subgraph_op.py | 135 ++++++
22 files changed, 1672 insertions(+), 11 deletions(-)
diff --git a/example/subgraph_op/common b/example/subgraph_op/common
new file mode 120000
index 0000000..cafb914
--- /dev/null
+++ b/example/subgraph_op/common
@@ -0,0 +1 @@
+../image-classification/common
\ No newline at end of file
diff --git a/example/subgraph_op/imagenet_inference.py b/example/subgraph_op/imagenet_inference.py
new file mode 100644
index 0000000..8a38cff
--- /dev/null
+++ b/example/subgraph_op/imagenet_inference.py
@@ -0,0 +1,166 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import argparse
+import logging
+import os
+import time
+import mxnet as mx
+from common import modelzoo
+from mxnet import nd
+from mxnet.contrib.quantization import *
+from mxnet.base import _LIB
+
+
+def download_dataset(dataset_url, dataset_dir, logger=None):
+    """Fetch the inference dataset with ``mx.test_utils.download``.
+
+    ``dataset_url`` and ``dataset_dir`` are forwarded unchanged as the first and
+    second positional arguments. If ``logger`` is given, the transfer is
+    announced at INFO level before it starts.
+    """
+    # NOTE(review): the caller passes a record-file path as ``dataset_dir``, so the
+    # second argument looks like a target file name rather than a directory -- confirm
+    # against the signature of ``mx.test_utils.download``.
+    if logger is not None:
+        source_and_target = (dataset_url, dataset_dir)
+        logger.info('Downloading dataset for inference from %s to %s' % source_and_target)
+    mx.test_utils.download(dataset_url, dataset_dir)
+
+
+def download_model(model_name, logger=None):
+    """Download a pretrained checkpoint into a ``model`` directory beside this script.
+
+    Parameters
+    ----------
+    model_name : str
+        Name forwarded to ``modelzoo.download_model``.
+    logger : logging.Logger, optional
+        When given, the download is announced at INFO level.
+
+    Returns
+    -------
+    Whatever ``modelzoo.download_model`` returns; the caller in this script
+    unpacks it as ``(prefix, epoch)``.
+    """
+    dir_path = os.path.dirname(os.path.realpath(__file__))
+    model_path = os.path.join(dir_path, 'model')
+    if logger is not None:
+        logger.info('Downloading model %s... into path %s' % (model_name, model_path))
+    # Use the argument instead of the script-level ``args`` so the function does
+    # what its signature says and also works when imported from another module.
+    return modelzoo.download_model(model_name, model_path)
+
+
+def advance_data_iter(data_iter, n):
+    """Skip up to ``n`` batches of ``data_iter`` and return the same iterator.
+
+    Parameters
+    ----------
+    data_iter : object with a ``next()`` method that raises ``StopIteration``
+        when no batch is left.
+    n : int
+        Number of batches to discard; must be non-negative.
+
+    Returns
+    -------
+    ``data_iter`` itself. If it runs out of batches before ``n`` were skipped,
+    the exhausted iterator is returned (previously ``None`` was returned in
+    that case, which broke callers that keep using the result).
+    """
+    assert n >= 0
+    remaining = n
+    while remaining > 0:
+        try:
+            data_iter.next()
+        except StopIteration:
+            # Fewer than ``n`` batches were available; stop skipping.
+            break
+        remaining -= 1
+    return data_iter
+
+
+def score(sym, arg_params, aux_params, data, devs, label_name, max_num_examples, logger=None):
+    """Run inference over ``data`` and report accuracy and throughput.
+
+    Parameters
+    ----------
+    sym : symbol handed to ``mx.mod.Module``.
+    arg_params, aux_params : parameters passed to ``Module.set_params``.
+    data : iterable of batches exposing ``provide_data`` and ``provide_label``.
+    devs : context(s) handed to ``mx.mod.Module``.
+    label_name : str
+        Name of the label input.
+    max_num_examples : int or None
+        Stop once at least this many examples were processed; ``None`` means
+        run until ``data`` is exhausted.
+    logger : logging.Logger, optional
+        When given, the image count, the speed and every metric are logged.
+    """
+    metrics = [mx.metric.create('acc'),
+               mx.metric.create('top_k_accuracy', top_k=5)]
+    mod = mx.mod.Module(symbol=sym, context=devs, label_names=[label_name, ])
+    mod.bind(for_training=False,
+             data_shapes=data.provide_data,
+             label_shapes=data.provide_label)
+    mod.set_params(arg_params, aux_params)
+
+    tic = time.time()
+    num = 0
+    for batch in data:
+        mod.forward(batch, is_train=False)
+        for m in metrics:
+            mod.update_metric(m, batch.label)
+        # Count examples from the batch itself instead of the script-level
+        # ``batch_size`` global, which only exists when run as ``__main__``.
+        # Assumes the leading dimension of the first data array is the batch size.
+        num += batch.data[0].shape[0]
+        if max_num_examples is not None and num >= max_num_examples:
+            break
+
+    speed = num / (time.time() - tic)
+
+    if logger is not None:
+        logger.info('Finished inference with %d images' % num)
+        logger.info('Finished with %f images per second', speed)
+        for m in metrics:
+            logger.info(m.get())
+
+
+if __name__ == '__main__':
+    # Script entry point: parse the options, build the record iterator, partition
+    # the downloaded model's graph and score the partitioned symbol.
+    parser = argparse.ArgumentParser(description='Score a model on a dataset')
+    parser.add_argument('--model', type=str, choices=['imagenet1k-resnet-152', 'imagenet1k-inception-bn'],
+                        help='currently only supports imagenet1k-resnet-152 or imagenet1k-inception-bn')
+    parser.add_argument('--batch-size', type=int, default=32)
+    parser.add_argument('--label-name', type=str, default='softmax_label')
+    parser.add_argument('--dataset', type=str, required=True, help='dataset path')
+    parser.add_argument('--rgb-mean', type=str, default='0,0,0')
+    parser.add_argument('--image-shape', type=str, default='3,224,224')
+    parser.add_argument('--data-nthreads', type=int, default=60, help='number of threads for data decoding')
+    parser.add_argument('--num-skipped-batches', type=int, default=0, help='skip the number of batches for inference')
+    # The value is multiplied by the batch size below, so it counts batches, not images.
+    parser.add_argument('--num-inference-batches', type=int, required=True, help='number of batches used for inference')
+    parser.add_argument('--shuffle-dataset', action='store_true', default=True,
+                        help='shuffle the calibration dataset')
+    parser.add_argument('--shuffle-chunk-seed', type=int, default=3982304,
+                        help='shuffling chunk seed, see'
+                             ' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter'
+                             ' for more details')
+    parser.add_argument('--shuffle-seed', type=int, default=48564309,
+                        help='shuffling seed, see'
+                             ' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter'
+                             ' for more details')
+
+    args = parser.parse_args()
+
+    logging.basicConfig()
+    logger = logging.getLogger('logger')
+    logger.setLevel(logging.INFO)
+    data_nthreads = args.data_nthreads
+    batch_size = args.batch_size
+    logger.info('batch size = %d for inference' % batch_size)
+
+    rgb_mean = args.rgb_mean
+    logger.info('rgb_mean = %s' % rgb_mean)
+    rgb_mean = [float(i) for i in rgb_mean.split(',')]
+    mean_args = {'mean_r': rgb_mean[0], 'mean_g': rgb_mean[1], 'mean_b': rgb_mean[2]}
+
+    label_name = args.label_name
+    logger.info('label_name = %s' % label_name)
+
+    image_shape = args.image_shape
+    data_shape = tuple([int(i) for i in image_shape.split(',')])
+    logger.info('Input data shape = %s' % str(data_shape))
+
+    dataset = args.dataset
+    # Pass the logger so the dataset download is announced like the model download.
+    download_dataset('http://data.mxnet.io/data/val_256_q90.rec', dataset, logger)
+    logger.info('Dataset for inference: %s' % dataset)
+
+    # creating data iterator
+    # The shuffle options come from the command line; their defaults equal the
+    # values that used to be hard-coded here, so default runs are unchanged.
+    data = mx.io.ImageRecordIter(path_imgrec=dataset,
+                                 label_width=1,
+                                 preprocess_threads=data_nthreads,
+                                 batch_size=batch_size,
+                                 data_shape=data_shape,
+                                 label_name=label_name,
+                                 rand_crop=False,
+                                 rand_mirror=False,
+                                 shuffle=args.shuffle_dataset,
+                                 shuffle_chunk_seed=args.shuffle_chunk_seed,
+                                 seed=args.shuffle_seed,
+                                 **mean_args)
+
+    # download model
+    prefix, epoch = download_model(model_name=args.model, logger=logger)
+    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
+    # Operators that the partitioner is asked to group into subgraph nodes.
+    op_names = ['BatchNorm', 'Convolution', 'Pooling', 'Activation']
+    # NOTE(review): SymbolHandle, check_call, mx_uint, c_str_array, ctypes and Symbol
+    # are not imported explicitly; presumably they arrive through the wildcard
+    # import of mxnet.contrib.quantization -- confirm.
+    out = SymbolHandle()
+    check_call(_LIB.MXPartitionGraph(sym.handle, mx_uint(len(op_names)), c_str_array(op_names),
+                                     ctypes.byref(out)))
+    psym = Symbol(out)
+
+    # make sure that fp32 inference works on the same images as calibrated quantized model
+    logger.info('Skipping the first %d batches' % args.num_skipped_batches)
+    data = advance_data_iter(data, args.num_skipped_batches)
+
+    num_inference_images = args.num_inference_batches * batch_size
+    logger.info('Running model %s for inference' % args.model)
+    score(psym, arg_params, aux_params, data, [mx.gpu(0)], label_name,
+          max_num_examples=num_inference_images, logger=logger)
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 4dd858a..8a714a9 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1457,6 +1457,11 @@ MXNET_DLL int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle,
const float* high_quantiles,
SymbolHandle* ret_sym_handle);
+MXNET_DLL int MXPartitionGraph(SymbolHandle sym_handle,
+ const mx_uint num_ops,
+ const char** op_names,
+ SymbolHandle* ret_sym_handle);
+
//--------------------------------------------
// Part 4: Executor interface
//--------------------------------------------
diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h
index fd1fe89..2424a67 100644
--- a/include/mxnet/engine.h
+++ b/include/mxnet/engine.h
@@ -41,8 +41,26 @@ class Engine;
/*! \brief namespace of engine internal types. */
namespace engine {
-/*! \brief Internal representation of variable. */
-struct Var;
+/*! \brief base class of engine variables.*/
+struct Var {
+ /*!
+ * \brief read the version number of this variable.
+ * Declared virtual so that derived variable types can provide their own
+ * implementation.
+ * \return the current value of version_.
+ */
+ virtual uint32_t version() {
+ return version_;
+ }
+ /*! \brief virtual destructor, so derived variables can be destroyed through a Var pointer. */
+ virtual ~Var() = default;
+ /*!
+ * \brief cast variable to derived type T
+ * \tparam T the type we want to cast into.
+ * \return A casted variable.
+ */
+ template <typename T>
+ inline T* Cast();
+ /*!
+ * \brief version number of the var. Every time the object it is associated with
+ * is modified, the version number is incremented by 1.
+ */
+ uint32_t version_{0};
+}; // struct Var
+
/*! \brief Internal representation of operator. */
struct Opr;
/*! \brief Variable pointer type, usually hold by user used to specify dependencies. */
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index faffe1b..f73a6ed 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -338,6 +338,10 @@ class NDArray {
inline size_t byte_offset() const {
return byte_offset_;
}
+ /*! \brief return var version of the NDArray*/
+ inline uint32_t version() const {
+ return var()->version();
+ }
/*!
* \brief save the content into binary stream
* \param strm the output stream
diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h
index f4694ef..ebe8249 100644
--- a/include/mxnet/op_attr_types.h
+++ b/include/mxnet/op_attr_types.h
@@ -98,7 +98,12 @@ enum class ExecType {
* In current implementation, copy operator is specially handled by executor.
* This flag is used for special case treatment and future extension of different copy ops.
*/
- kCrossDeviceCopy
+ kCrossDeviceCopy,
+ /*!
+ * A subgraph execution should happen in the main thread, instead of
+ * in the execution engine.
+ */
+ kSubgraphExec,
};
/*! \brief the dispatch mode of the operator */
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index e5e9b52..2f8c4f5 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -31,6 +31,7 @@
#include "./c_api_common.h"
#include "../operator/operator_common.h"
#include "../executor/exec_pass.h"
+#include "../operator/subgraph/default_subgraph_op.h"
namespace mxnet {
namespace op {
@@ -625,3 +626,27 @@ int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle,
*ret_qsym_handle = s;
API_END_HANDLE_ERROR(delete s);
}
+
+/*!
+ * \brief Copy the symbol behind sym_handle, run the "PartitionGraph" pass on it
+ * and hand the resulting symbol back through ret_sym_handle.
+ * \param sym_handle handle of the symbol to partition; it is copied, not modified.
+ * \param num_ops number of entries in op_names.
+ * \param op_names operator names used to build the subgraph property.
+ * \param ret_sym_handle receives the newly allocated partitioned symbol.
+ */
+int MXPartitionGraph(SymbolHandle sym_handle,
+ const mx_uint num_ops,
+ const char** op_names,
+ SymbolHandle* ret_sym_handle) {
+ // Allocated before API_BEGIN; the `delete s` given to API_END_HANDLE_ERROR below
+ // is the cleanup for this allocation.
+ nnvm::Symbol* s = new nnvm::Symbol();
+ API_BEGIN();
+ // Collect the operator names into a set (duplicates collapse).
+ std::unordered_set<std::string> op_name_set;
+ for (size_t i = 0; i < num_ops; ++i) {
+ op_name_set.emplace(op_names[i]);
+ }
+ nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(sym_handle);
+ // Work on a copy so the caller's symbol is left untouched.
+ *s = sym->Copy();
+ nnvm::Graph g = Symbol2Graph(*s);
+ // The "subgraph_property" attribute is only attached when at least one operator
+ // name was given; the pass is applied either way.
+ if (!op_name_set.empty()) {
+ mxnet::op::SubgraphPropertyPtr property
+ = std::make_shared<mxnet::op::DefaultSubgraphProperty>(op_name_set);
+ g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(property));
+ }
+ g = ApplyPass(std::move(g), "PartitionGraph");
+ // Replace the copy's outputs with the outputs of the partitioned graph.
+ s->outputs = g.outputs;
+ *ret_sym_handle = s;
+ API_END_HANDLE_ERROR(delete s);
+}
diff --git a/src/engine/engine_impl.h b/src/engine/engine_impl.h
index b3ec34d..9219b91 100644
--- a/src/engine/engine_impl.h
+++ b/src/engine/engine_impl.h
@@ -33,8 +33,12 @@
namespace mxnet {
namespace engine {
+#if 0
/*! \brief base class of engine variables, used for type checking */
struct Var {
+ virtual uint32_t version() {
+ return version_;
+ }
#if ENGINE_DEBUG
virtual ~Var() = default;
#endif // ENGINE_DEBUG
@@ -45,7 +49,13 @@ struct Var {
*/
template <typename T>
inline T* Cast();
+ /*!
+ * \brief version number of the var. Every time the object it is associated with
+ * is modified, the version number is incremented by 1.
+ */
+ uint32_t version_{0};
}; // struct Var
+#endif
/*! \brief base class of engine operators, used for type checking */
struct Opr {
diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc
index 8196af2..e0a47fa 100644
--- a/src/engine/naive_engine.cc
+++ b/src/engine/naive_engine.cc
@@ -28,10 +28,24 @@
#include "./engine_impl.h"
#include "../profiler/profiler.h"
#include "./openmp.h"
+#include "../common/object_pool.h"
namespace mxnet {
namespace engine {
+/*!
+ * \brief var used in Naive Engine for tracking the version
+ * of the objects it is associated with.
+ */
+class NaiveVar final
+ : public Var, public common::ObjectPoolAllocatable<NaiveVar> {
+ public:
+ /*!
+ * \brief cast a base Var pointer to a NaiveVar pointer via Var::Cast.
+ * \param ptr pointer to the base variable.
+ * \return the result of ptr->Cast<NaiveVar>().
+ */
+ inline static NaiveVar* CastFromBase(Var* ptr) {
+ return ptr->Cast<NaiveVar>();
+ }
+}; // class NaiveVar
+
+
// implement naive engine
class NaiveEngine final : public Engine {
public:
@@ -71,8 +85,11 @@ class NaiveEngine final : public Engine {
// new variables
VarHandle NewVariable() override {
+ return NaiveVar::New();
+#if 0
size_t v = ++counter_;
return reinterpret_cast<VarHandle>(v);
+#endif
}
OprHandle NewOperator(AsyncFn fn,
@@ -165,14 +182,26 @@ class NaiveEngine final : public Engine {
}
CHECK(this->req_completed_)
<< "NaiveEngine only support synchronize Push so far";
+ // increment var version
+ for (auto var : mutable_vars) {
+ ++var->version_;
+ }
if (profiling) {
opr->opr_profile->stop();
}
}
void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) override {
+ NaiveVar* naive_var = NaiveVar::CastFromBase(var);
+ this->PushAsync([delete_fn, naive_var](RunContext ctx, CallbackOnComplete on_complete) mutable {
+ delete_fn(ctx);
+ NaiveVar::Delete(naive_var);
+ on_complete();
+ }, exec_ctx, {}, {var}, FnProperty::kDeleteVar, 0, "DeleteVariable");
+#if 0
this->PushSync(delete_fn, exec_ctx, {}, {var},
FnProperty::kNormal, 0, "DeleteVariable");
+#endif
}
void WaitForVar(VarHandle var) override {
@@ -192,8 +221,6 @@ class NaiveEngine final : public Engine {
}
// whether action is completed
bool req_completed_;
- // counter
- std::atomic<size_t> counter_{0};
/*! \brief whether it is during shutdown phase*/
std::atomic<bool> shutdown_phase_{false};
// CPU stream
diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc
index e70cc19..bd11697 100644
--- a/src/engine/threaded_engine.cc
+++ b/src/engine/threaded_engine.cc
@@ -130,6 +130,9 @@ inline bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) {
assert(pending_write_ != nullptr);
CHECK_EQ(num_pending_reads_, kWriteTriggered);
+ // increment version number
+ ++version_;
+
// really delete
if (to_delete_) {
VersionedVarBlock *head = pending_write_->next;
@@ -164,7 +167,7 @@ inline bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) {
}
// This is outside of lock scope
// Be very carful, pending_write_ and num_pending_reads_
- // can change now, do not reply ont the two variables.
+ // can change now, do not rely on these two variables.
// The linked list \in [old_pending_write, end_of_read_chain)
// is already detached from this Var.
// So it is safe to modify these
@@ -196,6 +199,11 @@ inline bool ThreadedVar::ready_to_read() {
return this->is_ready_to_read();
}
+/*! \brief return the version number of this variable, read while holding mutex_. */
+inline uint32_t ThreadedVar::version() {
+ // Hold the variable's mutex for the duration of the read.
+ std::lock_guard<std::mutex> lock{mutex_};
+ return this->version_;
+}
+
// implementation of threaded engine
ThreadedVar* ThreadedEngine::NewVariable() {
return ThreadedVar::New(VersionedVarBlock::New());
diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h
index 428f0d8..7730c06 100644
--- a/src/engine/threaded_engine.h
+++ b/src/engine/threaded_engine.h
@@ -162,6 +162,7 @@ class ThreadedVar final
inline void SetToDelete();
/*! \return whether this variable is ready to read. */
inline bool ready_to_read();
+ inline uint32_t version() override;
/*!
* \brief Cast a Var pointer to ThreadedVar pointer
* \param ptr pointer from base.
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 831b5f9..ae05fe4 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -1614,6 +1614,9 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
CHECK_EQ(opnode.exec->in_array.size(), 1U);
CHECK_EQ(opnode.exec->out_array.size(), 1U);
CopyFromTo(opnode.exec->in_array[0], &(opnode.exec->out_array[0]));
+ } else if (opnode.exec->exec_type() == ExecType::kSubgraphExec) {
+ // If the node contains a subgraph, we can't execute it in the engine.
+ opnode.exec->Run(opnode.exec->op_ctx.run_ctx, false);
} else if (opnode.cached_opr != nullptr) {
bool profiling = profiler::Profiler::Get()->GetState() == profiler::Profiler::kRunning;
Engine::Get()->Push(opnode.cached_opr, opnode.ctx, 0, profiling);
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 726531d..08ea05a 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -434,7 +434,8 @@ inline void PushFComputeEx(const FComputeEx& fn,
}
};
- if (exec_type == ExecType::kCrossDeviceCopy) {
+ if (exec_type == ExecType::kCrossDeviceCopy
+ || exec_type == ExecType::kSubgraphExec) {
run(RunContext{ctx, nullptr});
} else {
CHECK(exec_type == ExecType::kSync);
@@ -475,12 +476,18 @@ inline void PushOperator(const OpStatePtr& state,
InvalidateOutputs(outputs, req);
#endif
fcompute_ex(state, opctx, inputs, req, outputs);
- if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync) {
+ if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync
+ && rctx.get_stream<gpu>()) {
rctx.get_stream<gpu>()->Wait();
}
};
- if (exec_type == ExecType::kSync) {
+ // For operators with subgraphs, we need to invoke them in the main thread
+ // instead of the threaded engine.
+ if (exec_type == ExecType::kSubgraphExec) {
+ RunContext rctx{ctx, nullptr};
+ run(rctx, engine::CallbackOnComplete());
+ } else if (exec_type == ExecType::kSync) {
Engine::Get()->PushSync(
[=](RunContext rctx) { run(rctx, engine::CallbackOnComplete()); },
ctx, read_vars, write_vars, FnProperty::kNormal, 0,
@@ -519,12 +526,16 @@ inline void PushOperator(const OpStatePtr& state,
fcompute(state, opctx, input_blobs, tmp_req, output_blobs);
// post-fcompute fallback, cast to original storage type, if necessary
CastNonDefaultStorage(post_temp_src, post_temp_dst, opctx, is_gpu);
- if (is_gpu && exec_type == ExecType::kSync) {
+ if (is_gpu && exec_type == ExecType::kSync
+ && rctx.get_stream<gpu>()) {
rctx.get_stream<gpu>()->Wait();
}
};
- if (exec_type == ExecType::kSync) {
+ if (exec_type == ExecType::kSubgraphExec) {
+ RunContext rctx{ctx, nullptr};
+ run(rctx, engine::CallbackOnComplete());
+ } else if (exec_type == ExecType::kSync) {
Engine::Get()->PushSync(
[=](RunContext rctx) {
run(rctx, engine::CallbackOnComplete());
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 94d3d90..583e2bf 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -39,6 +39,7 @@
#include "../operator/tensor/matrix_op-inl.h"
#include "../operator/tensor/init_op.h"
#include "../operator/nn/mkldnn/mkldnn_base-inl.h"
+#include "../engine/engine_impl.h"
#if MXNET_USE_OPENCV
#include <opencv2/opencv.hpp>
@@ -2041,6 +2042,7 @@ void NDArray::SyncCheckFormat(const bool full_check) const {
CHECK_EQ(err, kNormalErr) << "Check the validity of this sparse NDArray";
}
+
#if MXNET_PREDICT_ONLY == 0
// register API function
// those with underscore will be registered at NDArray
diff --git a/src/operator/subgraph/common.h b/src/operator/subgraph/common.h
new file mode 100644
index 0000000..472312d
--- /dev/null
+++ b/src/operator/subgraph/common.h
@@ -0,0 +1,270 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_COMMON_H_
+#define MXNET_OPERATOR_SUBGRAPH_COMMON_H_
+
+#include <string>
+#include <set>
+#include <vector>
+#include "../elemwise_op_common.h"
+#include "../../executor/exec_pass.h"
+
+namespace mxnet {
+namespace op {
+namespace sg {
+
+struct SimpleNode;
+using SimpleNodePtr = std::shared_ptr<SimpleNode>;
+
+/*!
+ * \brief Node of the undirected graph which replicates the network structures
+ * of the computational graph. It is used to ease the graph traversal for finding
+ * subgraphs.
+ */
+struct SimpleNode {
+ /*! \brief create a default-constructed node owned by a shared pointer */
+ static SimpleNodePtr Create() {
+ return std::make_shared<SimpleNode>();
+ }
+ /*! \brief a new node starts with label -1 and no referenced graph node */
+ SimpleNode() : label(-1), node(nullptr) {}
+ /*! subgraph label */
+ int label;
+ /*! the original node in the computational graph it references*/
+ nnvm::Node* node;
+ /*!
+ * \brief output nodes of the current node
+ * key is node ptr and value is an array of indices standing for the entry indices
+ * in key->inputs whose source is the current node.
+ */
+ std::unordered_map<nnvm::Node*, std::vector<size_t>> outputs;
+}; // struct SimpleNode
+} // namespace sg
+
+/*!
+ * \brief number of inputs of the subgraph op: the count of names that
+ * ListInputNames(kAll) returns for the symbol stored in attrs.parsed.
+ */
+inline uint32_t DefaultSubgraphOpNumInputs(const nnvm::NodeAttrs& attrs) {
+ const nnvm::Symbol& sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+ return sym.ListInputNames(nnvm::Symbol::kAll).size();
+}
+
+/*!
+ * \brief number of outputs of the subgraph op: the count of names that
+ * ListOutputNames() returns for the symbol stored in attrs.parsed.
+ */
+inline uint32_t DefaultSubgraphOpNumOutputs(const nnvm::NodeAttrs& attrs) {
+ const nnvm::Symbol& sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+ return sym.ListOutputNames().size();
+}
+
+/*!
+ * \brief input names of the subgraph op, taken from ListInputNames(kAll)
+ * of the symbol stored in attrs.parsed.
+ */
+inline std::vector<std::string> DefaultSubgraphOpListInputs(const nnvm::NodeAttrs& attrs) {
+ const nnvm::Symbol& sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+ return sym.ListInputNames(nnvm::Symbol::kAll);
+}
+
+/*!
+ * \brief output names of the subgraph op, taken from ListOutputNames()
+ * of the symbol stored in attrs.parsed.
+ */
+inline std::vector<std::string> DefaultSubgraphOpListOutputs(const nnvm::NodeAttrs& attrs) {
+ const nnvm::Symbol& sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+ return sym.ListOutputNames();
+}
+
+/*!
+ * \brief Shape inference for the subgraph op: seeds a graph built from the
+ * subgraph symbol with the known input/output shapes, runs exec::InferShape on
+ * it and writes the inferred shapes back.
+ * \param attrs node attributes; attrs.parsed holds the subgraph symbol.
+ * \param in_shapes input shapes, read as seeds and updated with the results.
+ * \param out_shapes output shapes, read as seeds and updated with the results.
+ * \return true when the pass reports zero nodes with unknown shape.
+ */
+inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs& attrs,
+ std::vector<TShape> *in_shapes,
+ std::vector<TShape> *out_shapes) {
+ using namespace exec;
+ const nnvm::Symbol& subgraph_sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+ // Build a temporary graph whose outputs are the subgraph symbol's outputs.
+ nnvm::Graph g;
+ g.outputs = subgraph_sym.outputs;
+ const auto& idx_g = g.indexed_graph();
+ CHECK_EQ(idx_g.input_nodes().size(), in_shapes->size());
+ CHECK_EQ(idx_g.outputs().size(), out_shapes->size());
+
+ // Put the input and output shapes to the shape vector.
+ nnvm::ShapeVector shapes(idx_g.num_node_entries());
+ const auto &input_nids = idx_g.input_nodes();
+ // Repeats the input-count check made above.
+ CHECK_EQ(input_nids.size(), in_shapes->size());
+ for (size_t i = 0; i < in_shapes->size(); i++) {
+ auto eid = idx_g.entry_id(input_nids[i], 0);
+ shapes[eid] = in_shapes->at(i);
+ }
+ CHECK_EQ(g.outputs.size(), out_shapes->size());
+ for (size_t i = 0; i < out_shapes->size(); i++) {
+ auto eid = idx_g.entry_id(g.outputs[i]);
+ shapes[eid] = out_shapes->at(i);
+ }
+
+ // Infer shape of the graph.
+ g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
+ // NOTE(review): idx_g and input_nids were obtained before g is reassigned here
+ // and are used again below; presumably the indexed graph survives the move
+ // through the pass -- confirm.
+ g = exec::InferShape(std::move(g));
+
+ // Copy the inferred shape back to the input shapes and the output shapes.
+ shapes = g.GetAttr<nnvm::ShapeVector>("shape");
+ // assign to in_shapes
+ for (size_t i = 0; i < in_shapes->size(); ++i) {
+ const auto eid = idx_g.entry_id(input_nids[i], 0);
+ SHAPE_ASSIGN_CHECK(*in_shapes, i, shapes[eid]);
+ }
+ // assign to out_shapes
+ for (size_t i = 0; i < g.outputs.size(); ++i) {
+ const auto eid = idx_g.entry_id(g.outputs[i]);
+ SHAPE_ASSIGN_CHECK(*out_shapes, i, shapes[eid]);
+ }
+ // Check if we have inferred the shapes correctly.
+ return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
+}
+
+/*!
+ * \brief Data type inference for the subgraph op: seeds a graph built from the
+ * subgraph symbol with the known input/output dtypes, runs exec::InferType on
+ * it and writes the inferred dtypes back.
+ * \param attrs node attributes; attrs.parsed holds the subgraph symbol.
+ * \param in_types input dtypes, read as seeds and updated with the results.
+ * \param out_types output dtypes, read as seeds and updated with the results.
+ * \return true when the pass reports zero nodes with unknown dtype.
+ */
+inline bool DefaultSubgraphOpType(const nnvm::NodeAttrs& attrs,
+ std::vector<int> *in_types,
+ std::vector<int> *out_types) {
+ const nnvm::Symbol& subgraph_sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+ // Build a temporary graph whose outputs are the subgraph symbol's outputs.
+ nnvm::Graph g;
+ g.outputs = subgraph_sym.outputs;
+ const auto& idx_g = g.indexed_graph();
+ CHECK_EQ(idx_g.input_nodes().size(), in_types->size());
+ CHECK_EQ(idx_g.outputs().size(), out_types->size());
+
+ // Put the input and output data types to the dtype vector.
+ // Every entry starts at -1 and is overwritten for the graph inputs and outputs.
+ nnvm::DTypeVector types(idx_g.num_node_entries(), -1);
+ const auto &input_nids = idx_g.input_nodes();
+ // Repeats the input-count check made above.
+ CHECK_EQ(input_nids.size(), in_types->size());
+ for (size_t i = 0; i < in_types->size(); i++) {
+ auto eid = idx_g.entry_id(input_nids[i], 0);
+ types[eid] = in_types->at(i);
+ }
+ CHECK_EQ(g.outputs.size(), out_types->size());
+ for (size_t i = 0; i < out_types->size(); i++) {
+ auto eid = idx_g.entry_id(g.outputs[i]);
+ types[eid] = out_types->at(i);
+ }
+
+ // Infer data type of the graph.
+ g.attrs["dtype"] = std::make_shared<dmlc::any>(std::move(types));
+ // NOTE(review): idx_g and input_nids were obtained before g is reassigned here
+ // and are used again below; presumably the indexed graph survives the move
+ // through the pass -- confirm.
+ g = exec::InferType(std::move(g));
+
+ // Copy the inferred dtypes back to the input and output dtype vectors.
+ types = g.GetAttr<nnvm::DTypeVector>("dtype");
+ // assign to in_types
+ for (size_t i = 0; i < in_types->size(); ++i) {
+ const auto eid = idx_g.entry_id(input_nids[i], 0);
+ TYPE_ASSIGN_CHECK(*in_types, i, types[eid]);
+ }
+ // assign to out_types
+ for (size_t i = 0; i < g.outputs.size(); ++i) {
+ const auto eid = idx_g.entry_id(g.outputs[i]);
+ TYPE_ASSIGN_CHECK(*out_types, i, types[eid]);
+ }
+ // Check if we have inferred the dtypes correctly.
+ return g.GetAttr<size_t>("dtype_num_unknown_nodes") == 0;
+}
+
+/*!
+ * \brief Storage type inference for the subgraph op: seeds a graph built from
+ * the subgraph symbol with the known input/output storage types, runs
+ * exec::InferStorageType on it and writes the results back.
+ * \param attrs node attributes; attrs.parsed holds the subgraph symbol.
+ * \param dev_mask device mask applied to every entry of the subgraph.
+ * \param dispatch_mode assigned DispatchMode::kFComputeEx.
+ * \param in_stypes input storage types, read as seeds and updated with the results.
+ * \param out_stypes output storage types, read as seeds and updated with the results.
+ * \return true when the pass reports zero nodes with unknown storage type.
+ */
+inline bool DefaultSubgraphOpStorageType(const nnvm::NodeAttrs& attrs,
+ const int dev_mask,
+ DispatchMode* dispatch_mode,
+ std::vector<int>* in_stypes,
+ std::vector<int>* out_stypes) {
+ const nnvm::Symbol& subgraph_sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+ // Build a temporary graph whose outputs are the subgraph symbol's outputs.
+ nnvm::Graph g;
+ g.outputs = subgraph_sym.outputs;
+ const auto& idx_g = g.indexed_graph();
+ CHECK_EQ(idx_g.input_nodes().size(), in_stypes->size());
+ CHECK_EQ(idx_g.outputs().size(), out_stypes->size());
+ exec::DevMaskVector dev_masks(idx_g.num_node_entries(), dev_mask);
+
+ // Put the input and output storages to the storage vector.
+ StorageTypeVector stypes(idx_g.num_node_entries(), kUndefinedStorage);
+ const auto &input_nids = idx_g.input_nodes();
+ // Repeats the input-count check made above.
+ CHECK_EQ(input_nids.size(), in_stypes->size());
+ for (size_t i = 0; i < in_stypes->size(); i++) {
+ auto eid = idx_g.entry_id(input_nids[i], 0);
+ stypes[eid] = in_stypes->at(i);
+ }
+ CHECK_EQ(g.outputs.size(), out_stypes->size());
+ for (size_t i = 0; i < out_stypes->size(); i++) {
+ auto eid = idx_g.entry_id(g.outputs[i]);
+ stypes[eid] = out_stypes->at(i);
+ }
+
+ // Infer storage type of the graph.
+ // NOTE(review): g is created locally above and no attribute is set on it before
+ // this point, so the "dev_mask" lookup presumably never matches -- confirm.
+ bool dev_match = g.attrs.count("dev_mask") &&
+ g.GetAttr<exec::DevMaskVector>("dev_mask") == dev_masks;
+ if (!dev_match) {
+ g.attrs["dev_mask"] = std::make_shared<dmlc::any>(std::move(dev_masks));
+ }
+ g.attrs["storage_type"] = std::make_shared<dmlc::any>(std::move(stypes));
+ // NOTE(review): idx_g and input_nids were obtained before g is reassigned here
+ // and are used again below; presumably the indexed graph survives the move
+ // through the pass -- confirm.
+ g = exec::InferStorageType(std::move(g));
+
+ // Copy the inferred storage types back to the input and output vectors.
+ stypes = g.GetAttr<StorageTypeVector>("storage_type");
+ // assign to in_stypes
+ for (size_t i = 0; i < in_stypes->size(); ++i) {
+ const auto eid = idx_g.entry_id(input_nids[i], 0);
+ STORAGE_TYPE_ASSIGN_CHECK(*in_stypes, i, stypes[eid]);
+ }
+
+ DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+ // assign to out_stypes
+ for (size_t i = 0; i < g.outputs.size(); ++i) {
+ const auto eid = idx_g.entry_id(g.outputs[i]);
+ STORAGE_TYPE_ASSIGN_CHECK(*out_stypes, i, stypes[eid]);
+ }
+ // Check if we have inferred the storages correctly.
+ return g.GetAttr<size_t>("storage_type_num_unknown_nodes") == 0;
+}
+
+/*! \brief execution type of the subgraph op; always ExecType::kSubgraphExec, attrs is unused. */
+inline ExecType DefaultSubgraphOpExecType(const nnvm::NodeAttrs& attrs) {
+ return ExecType::kSubgraphExec;
+}
+
+/*!
+ * \brief Indices, within the full input list of the subgraph symbol, of the
+ * inputs that are auxiliary states.
+ *
+ * The full list (kAll) is walked once while two cursors advance through the
+ * read-only-argument list and the auxiliary-state list; every position that is
+ * not the next read-only argument must be the next auxiliary state.
+ */
+inline std::vector<uint32_t> DefaultSubgraphOpMutableInputs(const nnvm::NodeAttrs& attrs) {
+  const nnvm::Symbol& subgraph_sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  const std::vector<std::string> all_names =
+      subgraph_sym.ListInputNames(nnvm::Symbol::kAll);
+  const std::vector<std::string> read_only_names =
+      subgraph_sym.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
+  const std::vector<std::string> aux_names =
+      subgraph_sym.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
+  CHECK_EQ(read_only_names.size() + aux_names.size(), all_names.size());
+
+  std::vector<uint32_t> mutable_indices;
+  size_t next_read_only = 0;
+  size_t next_aux = 0;
+  for (size_t pos = 0; pos < all_names.size(); ++pos) {
+    const bool is_read_only = next_read_only < read_only_names.size()
+                              && all_names[pos] == read_only_names[next_read_only];
+    if (is_read_only) {
+      ++next_read_only;
+      continue;
+    }
+    // Anything else has to line up with the next auxiliary state.
+    CHECK(next_aux < aux_names.size());
+    CHECK_EQ(all_names[pos], aux_names[next_aux]);
+    ++next_aux;
+    mutable_indices.push_back(pos);
+  }
+  return mutable_indices;
+}
+
+/*!
+ * \brief Resource requests of the subgraph op: the distinct resource types
+ * requested by the operator nodes reachable from the subgraph outputs.
+ */
+inline std::vector<ResourceRequest> DefaultSubgraphOpResourceRequest(const nnvm::NodeAttrs& attrs) {
+  const nnvm::Symbol& subgraph_sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  static auto& fresource = Op::GetAttr<FResourceRequest>("FResourceRequest");
+  // A set keeps one entry per resource type, however many nodes ask for it.
+  std::set<ResourceRequest::Type> resource_types;
+  DFSVisit(subgraph_sym.outputs, [&](const nnvm::NodePtr& node) {
+    // Variables and operators without a registered FResourceRequest add nothing.
+    if (node->is_variable() || !fresource.count(node->op())) {
+      return;
+    }
+    const std::vector<ResourceRequest> requests = fresource[node->op()](node->attrs);
+    for (const ResourceRequest& request : requests) {
+      resource_types.insert(request.type);
+    }
+  });
+  return std::vector<ResourceRequest>(resource_types.begin(), resource_types.end());
+}
+
+#if 0
+// TODO(junwu): add this attribute for visible outputs
+inline uint32_t DefaultSubgraphOpNumVisibleOutputs(const nnvm::NodeAttrs& attrs) {
+}
+#endif
+
+} // namespace op
+} // namespace mxnet
+
+#endif // MXNET_OPERATOR_SUBGRAPH_COMMON_H_
diff --git a/src/operator/subgraph/default_subgraph_op.cc b/src/operator/subgraph/default_subgraph_op.cc
new file mode 100644
index 0000000..8372ae9
--- /dev/null
+++ b/src/operator/subgraph/default_subgraph_op.cc
@@ -0,0 +1,113 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*/
+
+#include <mxnet/ndarray.h>
+#include "./default_subgraph_op.h"
+#include "../../imperative/imperative_utils.h"
+#include "../../imperative/cached_op.h"
+
+namespace mxnet {
+namespace op {
+
+#define SUBGRAPH_DEBUG 1  // non-zero compiles in the per-call version logging in Forward()
+
+/*!
+ * \brief Operator state of _default_subgraph_op: keeps the subgraph symbol and
+ *        the CachedOp that executes it.
+ */
+class DefaultSubgraphOperator {
+ public:
+  /*! \brief Builds a CachedOp for `sym` with the "static_alloc" flag set to "true". */
+  explicit DefaultSubgraphOperator(const Symbol& sym)
+      : subgraph_sym_(sym),
+        subgraph_exec_(new CachedOp(sym, {{"static_alloc", "true"}})) {}
+
+  /*! \brief Runs the subgraph forward through the CachedOp (defined out of line). */
+  void Forward(const OpContext& ctx,
+               const std::vector<NDArray>& inputs,
+               const std::vector<OpReqType>& req,
+               const std::vector<NDArray>& outputs);
+
+  /*! \brief Not supported: always aborts through LOG(FATAL). */
+  void Backward(const OpContext& ctx,
+                const std::vector<NDArray>& inputs,
+                const std::vector<OpReqType>& req,
+                const std::vector<NDArray>& outputs) {
+    LOG(FATAL) << "Not implemented";
+  }
+
+ private:
+  // symbol of the subgraph this operator executes
+  nnvm::Symbol subgraph_sym_;
+  // executor created from the symbol in the constructor
+  CachedOpPtr subgraph_exec_;
+};
+
+/*!
+ * \brief Executes the subgraph on the given arrays through the CachedOp.
+ *
+ * CachedOp::Forward takes vectors of NDArray pointers, so the (const) input and
+ * output arrays are first copied into local vectors and pointers to those
+ * copies are passed on.
+ * \param ctx operator context (not consulted here)
+ * \param inputs input arrays of the subgraph node
+ * \param req write requests of the outputs (not consulted here)
+ * \param outputs output arrays of the subgraph node
+ */
+void DefaultSubgraphOperator::Forward(const OpContext& ctx,
+                                      const std::vector<NDArray>& inputs,
+                                      const std::vector<OpReqType>& req,
+                                      const std::vector<NDArray>& outputs) {
+  std::vector<NDArray> tmp_inputs = inputs;
+  std::vector<NDArray*> input_ptrs;
+  input_ptrs.reserve(inputs.size());
+  for (auto& nd : tmp_inputs) {
+    input_ptrs.push_back(&nd);
+  }
+  std::vector<NDArray> tmp_outputs = outputs;
+  std::vector<NDArray*> output_ptrs;
+  // reserve up front, exactly as done for the inputs
+  output_ptrs.reserve(outputs.size());
+  for (auto& nd : tmp_outputs) {
+    output_ptrs.push_back(&nd);
+  }
+#if SUBGRAPH_DEBUG
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    LOG(INFO) << "inputs[" << i << "].version = " << inputs[i].version();
+  }
+  for (size_t i = 0; i < outputs.size(); ++i) {
+    LOG(INFO) << "outputs[" << i << "].version = " << outputs[i].version();
+  }
+#endif
+  subgraph_exec_->Forward(subgraph_exec_, input_ptrs, output_ptrs);
+}
+
+/*!
+ * \brief Creates the operator state for a _default_subgraph_op node.
+ * \param attrs node attributes whose `parsed` field holds the subgraph symbol
+ * \param ctx execution context (unused)
+ * \param in_shapes input shapes (unused)
+ * \param in_types input types (unused)
+ * \return state wrapping a DefaultSubgraphOperator built from the symbol
+ */
+OpStatePtr CreateDefaultSubgraphOpState(const NodeAttrs& attrs,
+                                        Context ctx,
+                                        const std::vector<TShape>& in_shapes,
+                                        const std::vector<int>& in_types) {
+  return OpStatePtr::Create<DefaultSubgraphOperator>(nnvm::get<Symbol>(attrs.parsed));
+}
+
+/*!
+ * \brief Stateful compute entry point: forwards the call to the
+ *        DefaultSubgraphOperator stored in `state_ptr`.
+ */
+void DefaultSubgraphOpForward(const OpStatePtr& state_ptr,
+                              const OpContext& ctx,
+                              const std::vector<NDArray>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<NDArray>& outputs) {
+  state_ptr.get_state<DefaultSubgraphOperator>().Forward(ctx, inputs, req, outputs);
+}
+
+NNVM_REGISTER_OP(_default_subgraph_op)  // operator registration; GPU compute is added in the .cu file
+.describe(R"code(_default_subgraph_op)code" ADD_FILELINE)
+.set_num_inputs(DefaultSubgraphOpNumInputs)
+.set_num_outputs(DefaultSubgraphOpNumOutputs)
+.set_attr<nnvm::FListInputNames>("FListInputNames", DefaultSubgraphOpListInputs)
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", DefaultSubgraphOpListOutputs)
+.set_attr<FCreateOpState>("FCreateOpState", CreateDefaultSubgraphOpState)
+.set_attr<nnvm::FInferShape>("FInferShape", DefaultSubgraphOpShape)
+.set_attr<nnvm::FInferType>("FInferType", DefaultSubgraphOpType)
+.set_attr<FInferStorageType>("FInferStorageType", DefaultSubgraphOpStorageType)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", DefaultSubgraphOpForward)  // CPU forward
+.set_attr<nnvm::FMutateInputs>("FMutateInputs", DefaultSubgraphOpMutableInputs)
+.set_attr<FResourceRequest>("FResourceRequest", DefaultSubgraphOpResourceRequest)
+.set_attr<std::string>("key_var_num_args", "num_args")
+.set_attr<FExecType>("FExecType", DefaultSubgraphOpExecType)
+.add_argument("data", "NDArray-or-Symbol[]", "input data list");
+
+} // namespace op
+} // namespace mxnet
diff --git a/src/operator/subgraph/default_subgraph_op.cu b/src/operator/subgraph/default_subgraph_op.cu
new file mode 100644
index 0000000..15a76e3
--- /dev/null
+++ b/src/operator/subgraph/default_subgraph_op.cu
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file default_subgraph_op.cu
+ * \brief GPU Implementation of subgraph operations
+ */
+
+#include "./default_subgraph_op.h"
+
+namespace mxnet {
+namespace op {
+
+void DefaultSubgraphOpForward(const OpStatePtr& state_ptr,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs);  // declaration only; defined in default_subgraph_op.cc
+
+NNVM_REGISTER_OP(_default_subgraph_op)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", DefaultSubgraphOpForward);  // same function as for cpu
+
+} // namespace op
+} // namespace mxnet
diff --git a/src/operator/subgraph/default_subgraph_op.h b/src/operator/subgraph/default_subgraph_op.h
new file mode 100644
index 0000000..7d6624e
--- /dev/null
+++ b/src/operator/subgraph/default_subgraph_op.h
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_DEFAULT_SUBGRAPH_OP_H_
+#define MXNET_OPERATOR_SUBGRAPH_DEFAULT_SUBGRAPH_OP_H_
+
+#include <vector>
+#include <string>
+#include "./common.h"
+
+namespace mxnet {
+namespace op {
+
+/*!
+ * \brief Criteria for selecting the nodes of a subgraph.
+ *
+ * A selector may change its criteria as nodes are passed to it, and it decides
+ * which links (inputs and/or outputs) are followed when neighbor nodes are
+ * traversed.
+ */
+class SubgraphSelector {
+ public:
+  virtual ~SubgraphSelector() {}
+
+  /*! \brief Decides whether node `n` should be selected for a subgraph. */
+  virtual bool Select(const nnvm::Node &n) = 0;
+
+  /*! \brief Decides whether the input node `new_node` of `n` should be selected. */
+  virtual bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) = 0;
+
+  /*! \brief Decides whether the output node `new_node` of `n` should be selected. */
+  virtual bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) = 0;
+
+  /*!
+   * \brief Post-processes the pre-selected subgraph nodes.
+   * \return the nodes that should be kept in subgraph(s); by default all candidates
+   */
+  virtual std::vector<nnvm::Node*> Filter(nnvm::Graph* g,
+                                          const std::vector<nnvm::Node*>& candidates) {
+    return candidates;
+  }
+};
+
+using SubgraphSelectorPtr = std::shared_ptr<SubgraphSelector>;
+
+/*!
+ * \brief This provides a set of properties for partitioning a graph into subgraphs,
+ *        reconstructing a new graph from the subgraphs and creating a subgraph
+ *        operator to execute the subgraph.
+ */
+class SubgraphProperty {
+ public:
+  // Polymorphic base class: a virtual destructor makes destruction through a
+  // base pointer well defined (same as SubgraphSelector).
+  virtual ~SubgraphProperty() {}
+  /*! \brief Creates the selector that holds the criteria for choosing subgraph nodes. */
+  virtual SubgraphSelectorPtr CreateSubgraphSelector() const = 0;
+  /*!
+   * \brief Creates an nnvm node for a given subgraph. Here users can customize
+   *        how to execute the operators in the subgraph.
+   * \param s symbol of the subgraph
+   * \param subgraph_id id of the subgraph
+   * \return the node that stands for the subgraph
+   */
+  virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &s,
+                                           const int subgraph_id = 0) const = 0;
+};
+
+using SubgraphPropertyPtr = std::shared_ptr<SubgraphProperty>;
+
+void RegisterSubgraphProperty(SubgraphPropertyPtr property);
+
+/*!
+ * \brief Selector that accepts only operator nodes whose operator name is in a
+ *        given set. The same rule is applied to seed nodes, input neighbors and
+ *        output neighbors, so traversal follows both input and output links.
+ */
+class ContainOpSelector: public SubgraphSelector {
+  // names of the operators that may be placed in a subgraph
+  std::shared_ptr<const std::unordered_set<std::string>> op_names;
+
+  /*! \brief True when `node` is an operator node whose name is in the set. */
+  bool IsListedOp(const nnvm::Node &node) const {
+    if (node.is_variable()) {
+      return false;
+    }
+    return op_names->count(node.op()->name) != 0;
+  }
+
+ public:
+  explicit ContainOpSelector(std::shared_ptr<const std::unordered_set<std::string>> op_names)
+      : op_names(op_names) {}
+
+  bool Select(const nnvm::Node &n) override {
+    return IsListedOp(n);
+  }
+
+  bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override {
+    return IsListedOp(new_node);
+  }
+
+  bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override {
+    return IsListedOp(new_node);
+  }
+};
+
+/*!
+ * \brief Subgraph property that groups nodes whose operators all belong to a
+ *        given set of names. Each subgraph is wrapped in a _default_subgraph_op node.
+ */
+class DefaultSubgraphProperty: public SubgraphProperty {
+ public:
+  explicit DefaultSubgraphProperty(const std::unordered_set<std::string> &op_names)
+      : op_names_(std::make_shared<std::unordered_set<std::string>>(op_names)) {}
+
+  /*! \brief Creates a node named "_default_subgraph_op<subgraph_id>" that stores `sym`. */
+  nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym,
+                                   const int subgraph_id = 0) const override {
+    const std::string op_name = "_default_subgraph_op";
+    nnvm::NodePtr subgraph_node = nnvm::Node::Create();
+    subgraph_node->attrs.op = Op::Get(op_name);
+    subgraph_node->attrs.name = op_name + std::to_string(subgraph_id);
+    subgraph_node->attrs.parsed = sym;
+    return subgraph_node;
+  }
+
+  /*! \brief Creates a ContainOpSelector that shares this property's operator-name set. */
+  SubgraphSelectorPtr CreateSubgraphSelector() const override {
+    return std::make_shared<ContainOpSelector>(op_names_);
+  }
+
+ private:
+  // names of the operators that may be placed in a subgraph
+  std::shared_ptr<const std::unordered_set<std::string>> op_names_;
+};
+
+} // namespace op
+} // namespace mxnet
+
+#endif // MXNET_OPERATOR_SUBGRAPH_DEFAULT_SUBGRAPH_OP_H_
diff --git a/src/operator/subgraph/partition_graph.cc b/src/operator/subgraph/partition_graph.cc
new file mode 100644
index 0000000..11af49a
--- /dev/null
+++ b/src/operator/subgraph/partition_graph.cc
@@ -0,0 +1,687 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file partition_graph.cc
+ * \brief Helpers that find subgraphs in a computation graph and replace each one with a subgraph node.
+ */
+#include <nnvm/graph.h>
+#include <nnvm/pass.h>
+#include <mxnet/op_attr_types.h>
+#include <unordered_set>
+#include <stack>
+#include <queue>
+
+#include "./default_subgraph_op.h"
+#include "./common.h"
+
+namespace nnvm {
+NodePtr CreateVariableNode(const std::string& name);
+}
+
+namespace mxnet {
+
+namespace op {
+
+using nnvm::Symbol;
+using nnvm::Node;
+using nnvm::NodePtr;
+using nnvm::NodeEntry;
+using nnvm::Graph;
+
+// TODO(junwu): Change this to 0
+#define SUBGRAPH_DEBUG 1
+
+namespace sg { // sg stands for subgraph
+
+#if SUBGRAPH_DEBUG
+/*! \brief Logs the names of all nodes of a subgraph on one line, separated by spaces. */
+void PrintSubgraph(const std::vector<SimpleNode*>& simple_nodes) {
+  std::string op_names;
+  for (const SimpleNode* sn : simple_nodes) {
+    op_names.append(sn->node->attrs.name);
+    op_names.push_back(' ');
+  }
+  LOG(INFO) << "Subgraph node names: " << op_names;
+}
+
+/*! \brief Logs the source node name, output index and version of a node entry. */
+void PrintNodeEntry(const nnvm::NodeEntry& entry) {
+  LOG(INFO) << "NodeEntry: node_name=" << entry.node->attrs.name
+            << ", index=" << std::to_string(entry.index)
+            << ", version=" << std::to_string(entry.version);
+}
+
+/*! \brief Logs every node entry of the list, one per line, in list order. */
+void PrintNodeEntries(const std::vector<nnvm::NodeEntry*>& entries) {
+  for (const nnvm::NodeEntry* entry : entries) {
+    PrintNodeEntry(*entry);
+  }
+}
+#endif
+
+/*!
+ * \brief Given a MXNet computational graph, create an undirected graph from it.
+ *
+ * One simple node is created per graph node in DFS visiting order. For every
+ * input entry of a node, the producer's `outputs` map records the consumer
+ * together with the positions of the consumer's inputs that read from it.
+ * \param g the MXNet computational graph
+ * \param simple_nodes the nodes of undirected graph in top sorted order
+ */
+void CreateSimpleGraph(const Graph& g,
+                       std::vector<SimpleNodePtr>* simple_nodes) {
+  const auto& indexed_graph = g.indexed_graph();
+  simple_nodes->reserve(indexed_graph.num_nodes());
+  DFSVisit(g.outputs, [&](const NodePtr& node) {
+    SimpleNodePtr sn = SimpleNode::Create();
+    sn->node = node.get();
+    const auto& node_inputs = sn->node->inputs;
+    for (size_t i = 0; i < node_inputs.size(); ++i) {
+      const auto input_nid = indexed_graph.node_id(node_inputs[i].node.get());
+      // the producer must already have been visited
+      CHECK_LT(input_nid, simple_nodes->size());
+      // operator[] creates the empty position list on first use
+      (*simple_nodes)[input_nid]->outputs[sn->node].push_back(i);
+    }
+    simple_nodes->emplace_back(std::move(sn));
+  });
+}
+
+/*!
+ * \brief Sets the label of every listed subgraph node back to -1 and then
+ *        empties the list.
+ * \param g the whole graph
+ * \param simple_nodes all simple nodes, indexed by node id
+ * \param subgraph_nodes nodes whose labels are reset; cleared on return
+ */
+void ResetNodeLabels(const nnvm::Graph& g,
+                     const std::vector<SimpleNodePtr>& simple_nodes,
+                     std::vector<nnvm::Node*>* subgraph_nodes) {
+  const auto& indexed_graph = g.indexed_graph();
+  for (const nnvm::Node* labeled : *subgraph_nodes) {
+    simple_nodes[indexed_graph.node_id(labeled)]->label = -1;
+  }
+  subgraph_nodes->clear();
+}
+
+/*!
+ * \brief This function traverses the nodes in a computation graph from a starting
+ *        node following the input edges and output edges, and marks all nodes that
+ *        can be accessed from the starting node. Before the function returns,
+ *        it checks whether there is a loop between the potential subgraph
+ *        and the outside nodes. If so, it adds the node that should break the loop
+ *        to excluded_nodes (when one is supplied), resets the labels and returns
+ *        false. Otherwise, it returns true.
+ * \param g the whole graph
+ * \param subgraph_selector determines whether the visited node should be chosen or not
+ * \param label the label of the current subgraph
+ * \param snid node id of the seed simple node
+ * \param simple_nodes all simple nodes in the top sorted order
+ * \param subgraph_nodes all the nodes belonging to the same subgraph of seed node
+ * \param excluded_nodes set of nodes that should be excluded from the current subgraph;
+ *        may be nullptr, in which case nothing is excluded and nothing is recorded
+ * \return true if a loop-free subgraph was labeled, false if a loop was found
+ */
+bool LabelSubgraph(const Graph& g,
+                   SubgraphSelectorPtr subgraph_selector,
+                   const int label,
+                   const size_t snid,  // simple node id, this is a seed
+                   const std::vector<SimpleNodePtr>& simple_nodes,
+                   std::vector<nnvm::Node*>* subgraph_nodes,
+                   std::unordered_set<const nnvm::Node*>* excluded_nodes = nullptr) {
+  const auto& indexed_graph = g.indexed_graph();
+  std::queue<SimpleNode*> node_queue;
+  // excluded_nodes is optional, so every access goes through this null-safe test
+  auto is_excluded = [&](const nnvm::Node* node) {
+    return excluded_nodes != nullptr && excluded_nodes->count(node) != 0;
+  };
+  // labels a selected node and schedules it, unless it already carries a label
+  auto enqueue_if_unvisited = [&](const nnvm::Node* node) {
+    const auto nid = indexed_graph.node_id(node);
+    CHECK_LT(nid, simple_nodes.size());
+    if (simple_nodes[nid]->label == -1) {
+      simple_nodes[nid]->label = label;
+      node_queue.push(simple_nodes[nid].get());
+    }
+  };
+  if (!is_excluded(simple_nodes[snid]->node)) {
+    CHECK_EQ(simple_nodes[snid]->label, -1);
+    simple_nodes[snid]->label = label;
+    node_queue.push(simple_nodes[snid].get());
+  }
+  // key: nodes that serve as input/output nodes to the subgraph
+  // value: pair of vectors of nodes in the subgraph. The first vector contains the
+  // output nodes of the key in the subgraph, and the second vector contains the
+  // input nodes of the key in the subgraph. If both vectors are non-empty,
+  // it means there is a loop between the subgraph and the key node.
+  // When breaking the loop, we want to start removing the node with the largest node id.
+  std::unordered_map<const nnvm::Node*,
+                     std::pair<std::vector<const nnvm::Node*>,
+                               std::vector<const nnvm::Node*>>> non_subgraph_node_map;
+  while (!node_queue.empty()) {
+    SimpleNode* cur_node = node_queue.front();
+    node_queue.pop();
+    subgraph_nodes->push_back(cur_node->node);
+    // get qualified adjacent input nodes
+    for (auto& e : cur_node->node->inputs) {
+      const bool select_input = !is_excluded(e.node.get())
+        && subgraph_selector->SelectInput(*cur_node->node, *e.node);
+      if (select_input) {
+        // e.node is a subgraph node
+        enqueue_if_unvisited(e.node.get());
+      } else {
+        // e.node is an input node of the subgraph
+        non_subgraph_node_map[e.node.get()].first.push_back(cur_node->node);
+      }
+    }
+    // get qualified output nodes
+    for (auto it = cur_node->outputs.begin(); it != cur_node->outputs.end(); ++it) {
+      const bool select_output = !is_excluded(it->first)
+        && subgraph_selector->SelectOutput(*cur_node->node, *it->first);
+      if (select_output) {
+        // it->first is a subgraph node
+        enqueue_if_unvisited(it->first);
+      } else {
+        // it->first is an output node of the subgraph
+        non_subgraph_node_map[it->first].second.push_back(cur_node->node);
+      }
+    }
+  }
+  auto node_cmp = [&] (const nnvm::Node* node1, const nnvm::Node* node2) {
+    return indexed_graph.node_id(node1) < indexed_graph.node_id(node2);
+  };
+  // check whether there is a loop between the subgraph and its input/output nodes
+  int excluded_node_id = -1;
+  for (const auto& kv : non_subgraph_node_map) {
+    const auto& output_nodes = kv.second.first;
+    const auto& input_nodes = kv.second.second;
+    if (output_nodes.empty() || input_nodes.empty()) {
+      continue;
+    }
+    // there is a loop between kv.first and the subgraph; only the subgraph node
+    // with the largest node id is needed, so no full sort is required
+    const nnvm::Node* last_output =
+        *std::max_element(output_nodes.begin(), output_nodes.end(), node_cmp);
+    const nnvm::Node* last_input =
+        *std::max_element(input_nodes.begin(), input_nodes.end(), node_cmp);
+    const auto node_id = std::max(indexed_graph.node_id(last_output),
+                                  indexed_graph.node_id(last_input));
+    excluded_node_id = std::max(excluded_node_id, static_cast<int>(node_id));
+  }
+  if (excluded_node_id != -1) {
+    CHECK_LT(excluded_node_id, static_cast<int>(simple_nodes.size()));
+    CHECK_NE(excluded_node_id, static_cast<int>(snid))
+      << "A cycle is found in the computational graph between nodes "
+      << simple_nodes[excluded_node_id]->node->attrs.name << " and "
+      << simple_nodes[snid]->node->attrs.name;
+    // excluded_nodes defaults to nullptr; only record the node when a set was given
+    if (excluded_nodes != nullptr) {
+      excluded_nodes->insert(simple_nodes[excluded_node_id]->node);
+    }
+    ResetNodeLabels(g, simple_nodes, subgraph_nodes);
+    return false;
+  }
+  std::sort(subgraph_nodes->begin(), subgraph_nodes->end(), node_cmp);
+  return true;
+}
+
+/*!
+ * \brief Finds all the nodes belonging to the same subgraph given a seed node.
+ *
+ * LabelSubgraph is retried, with a growing set of excluded nodes, until it
+ * succeeds or the retry budget (number of simple nodes squared) is used up.
+ * If it never succeeds, the seed node alone is labeled and returned.
+ * \param g the whole graph
+ * \param subgraph_selector determines whether the visited node should be chosen or not
+ * \param label the label of the current subgraph
+ * \param snid node id of the seed simple node
+ * \param simple_nodes all simple nodes in the top sorted order
+ * \param subgraph_nodes output: all the nodes belonging to the same subgraph of seed node
+ */
+void PreSelectSubgraphNodes(const Graph& g,
+                            SubgraphSelectorPtr subgraph_selector,
+                            const int label,
+                            const size_t snid,
+                            const std::vector<SimpleNodePtr>& simple_nodes,
+                            std::vector<nnvm::Node*>* subgraph_nodes) {
+  std::unordered_set<const nnvm::Node*> excluded_nodes;
+  const size_t max_num_retry = simple_nodes.size() * simple_nodes.size();
+  size_t count = 0;
+  bool success = false;
+  while (!success && count < max_num_retry) {
+    success = LabelSubgraph(g, subgraph_selector, label, snid, simple_nodes,
+                            subgraph_nodes, &excluded_nodes);
+    if (!success) {
+      CHECK(!excluded_nodes.empty());
+      std::string excluded_node_names;
+      for (auto node : excluded_nodes) {
+        excluded_node_names += node->attrs.name + ", ";
+      }
+      LOG(INFO) << "Found a cycle when BFS from node " << simple_nodes[snid]->node->attrs.name
+                << ". Excluding nodes " << excluded_node_names << "and retrying";
+    }
+    ++count;
+  }
+  if (!success) {
+    // the leading space in " as a subgraph" keeps the node name and the next word apart
+    LOG(INFO) << "Tried " << count << " times of finding subgraphs starting from node "
+              << simple_nodes[snid]->node->attrs.name << " without success because a loop "
+                 "is always found between the subgraph and some other nodes. Will treat "
+                 "seed node " << simple_nodes[snid]->node->attrs.name
+              << " as a subgraph with one node";
+    CHECK(subgraph_nodes->empty());
+    simple_nodes[snid]->label = label;
+    subgraph_nodes->push_back(simple_nodes[snid]->node);
+  }
+}
+
+/*!
+ * \brief Given a vector of nodes, group them into individual subgraphs
+ *        based upon their connectivity.
+ *
+ * Starting from each candidate that is not grouped yet, a breadth-first walk
+ * over input and output links gathers every candidate reachable through other
+ * candidates. Each group is appended to `subgraphs` sorted by node id, its
+ * simple nodes are labeled with the current `*subgraph_id`, and the id is then
+ * incremented.
+ */
+void PostProcessNodeCandidates(const nnvm::Graph& g,
+                               const std::vector<nnvm::Node*>& nodes,
+                               const std::vector<SimpleNodePtr>& simple_nodes,
+                               std::vector<std::vector<SimpleNode*>>* subgraphs,
+                               size_t* subgraph_id) {
+  const auto& indexed_graph = g.indexed_graph();
+  // candidates that have not been placed in a group yet
+  std::unordered_set<nnvm::Node*> node_set(nodes.begin(), nodes.end());
+  auto simple_node_cmp = [&] (const SimpleNode* node1, const SimpleNode* node2) {
+    return indexed_graph.node_id(node1->node) < indexed_graph.node_id(node2->node);
+  };
+  // drained by the inner while loop, so it is empty at the start of every group
+  std::queue<nnvm::Node*> q;
+  // queues `member`, labels it and appends it to the group under construction
+  auto add_to_group = [&](nnvm::Node* member) {
+    q.push(member);
+    SimpleNode* member_sn = simple_nodes[indexed_graph.node_id(member)].get();
+    member_sn->label = *subgraph_id;
+    subgraphs->back().push_back(member_sn);
+  };
+  for (nnvm::Node* node : nodes) {
+    if (node_set.count(node) == 0) {
+      // The node has been included in a subgraph
+      continue;
+    }
+    CHECK_EQ(node_set.erase(node), 1U);
+    subgraphs->emplace_back();
+    add_to_group(node);
+    while (!q.empty()) {
+      nnvm::Node* cur_node = q.front();
+      q.pop();
+      for (auto& e : cur_node->inputs) {
+        nnvm::Node* producer = e.node.get();
+        // erase() reports whether the producer was still an ungrouped candidate
+        if (node_set.erase(producer) != 0) {
+          add_to_group(producer);
+        }
+      }
+      const SimpleNode* cur_snode = simple_nodes[indexed_graph.node_id(cur_node)].get();
+      for (const auto& kv : cur_snode->outputs) {
+        if (node_set.erase(kv.first) != 0) {
+          add_to_group(kv.first);
+        }
+      }
+    }
+    ++(*subgraph_id);
+    std::sort(subgraphs->back().begin(), subgraphs->back().end(), simple_node_cmp);
+  }
+  CHECK(node_set.empty());
+}
+
+/*!
+ * \brief Finds subgraphs with all nodes that meet certain criteria.
+ *        All nodes in a subgraph are marked with the same label.
+ *
+ * For every unlabeled node accepted by a fresh selector, candidate nodes are
+ * pre-selected, passed through the selector's Filter, checked to be a subset
+ * of the pre-selection, and finally split into connected subgraphs.
+ * \param g the whole graph
+ * \param subg_prop property that creates the subgraph selectors
+ * \param simple_nodes all simple nodes in the top sorted order
+ * \param subgraph_nodes output: one vector of simple nodes per subgraph found
+ */
+void FindSubgraphs(Graph* g,
+                   const SubgraphProperty &subg_prop,
+                   const std::vector<SimpleNodePtr>& simple_nodes,
+                   std::vector<std::vector<SimpleNode*>>* subgraph_nodes) {
+  const auto& indexed_graph = g->indexed_graph();
+  CHECK_EQ(indexed_graph.num_nodes(), simple_nodes.size());
+  auto node_cmp = [&] (const nnvm::Node* node1, const nnvm::Node* node2) {
+    return indexed_graph.node_id(node1) < indexed_graph.node_id(node2);
+  };
+  size_t subgraph_id = 0;
+  for (size_t i = 0; i < simple_nodes.size(); ++i) {
+    nnvm::Node* node = simple_nodes[i]->node;
+    auto subgraph_selector = subg_prop.CreateSubgraphSelector();
+    if (!subgraph_selector->Select(*node) || simple_nodes[i]->label != -1) {
+      continue;
+    }
+    // pre-select nodes that can be grouped in a subgraph
+    std::vector<nnvm::Node*> preselected_nodes;
+    PreSelectSubgraphNodes(*g, subgraph_selector, static_cast<int>(subgraph_id), i,
+                           simple_nodes, &preselected_nodes);
+
+    // filter out unqualified pre-selected nodes
+    std::vector<nnvm::Node*> filtered_nodes = subgraph_selector->Filter(g, preselected_nodes);
+
+    // make sure filtered_nodes is a subset of preselected_nodes
+    // (hash-set lookups instead of a linear std::find per node)
+    const std::unordered_set<nnvm::Node*> preselected_set(preselected_nodes.begin(),
+                                                          preselected_nodes.end());
+    for (nnvm::Node* n : filtered_nodes) {
+      CHECK(preselected_set.count(n) != 0)
+        << "Node " << n->attrs.name << " is not found in the pre-selected subgraph nodes."
+           " Please make sure that no new nodes were added in your subgraph"
+           " selector's Filter function";
+    }
+
+    // make sure nodes are sorted
+    std::sort(filtered_nodes.begin(), filtered_nodes.end(), node_cmp);
+
+    // reset node labels that are not in filtered nodes
+    const std::unordered_set<nnvm::Node*> filtered_set(filtered_nodes.begin(),
+                                                       filtered_nodes.end());
+    for (nnvm::Node* n : preselected_nodes) {
+      if (filtered_set.count(n) == 0) {
+        simple_nodes[indexed_graph.node_id(n)]->label = -1;
+      }
+    }
+    // find out subgraphs from the filtered nodes
+    std::vector<std::vector<SimpleNode*>> subgraphs;
+    PostProcessNodeCandidates(*g, filtered_nodes, simple_nodes, &subgraphs, &subgraph_id);
+    if (!subgraphs.empty()) {
+      subgraph_nodes->insert(subgraph_nodes->end(), subgraphs.begin(), subgraphs.end());
+    }
+  }
+}
+
+/*!
+ * \brief Orders node entries by their topological position.
+ *        Entry ids cannot be used for this, so the positions are looked up in
+ *        a map; every entry must be present in it (enforced with CHECKs).
+ * \param entry_top_order_map mapping from entry pointer to its topological position in the graph
+ * \param entries Node entries to be sorted in place
+ */
+void SortEntries(const std::unordered_map<const nnvm::NodeEntry*, size_t>& entry_top_order_map,
+                 std::vector<nnvm::NodeEntry*>* entries) {
+  std::sort(entries->begin(), entries->end(),
+            [&entry_top_order_map](const nnvm::NodeEntry* e1, const nnvm::NodeEntry* e2) {
+              const auto it1 = entry_top_order_map.find(e1);
+              CHECK(it1 != entry_top_order_map.end());
+              const auto it2 = entry_top_order_map.find(e2);
+              CHECK(it2 != entry_top_order_map.end());
+              return it1->second < it2->second;
+            });
+}
+
+/*!
+ * \brief Given a subgraph, find the input entries of the subgraph.
+ *
+ * An input entry of a subgraph node is collected when its source node either
+ * is not part of the indexed graph or carries a label different from the
+ * subgraph's. All subgraph nodes are checked to share one label. The collected
+ * entries are sorted with SortEntries before returning.
+ * \param g the whole graph
+ * \param simple_nodes vector of simple nodes in top sorted order
+ * \param subgraph_nodes vector of pointers to the simple nodes of a subgraph
+ * \param entry_top_order_map mapping entry pointer to its top sorted position
+ * \param input_entries output: input entries of the subgraph
+ */
+void FindInputEntries(const Graph& g,
+                      const std::vector<SimpleNodePtr>& simple_nodes,
+                      const std::vector<SimpleNode*>& subgraph_nodes,
+                      const std::unordered_map<const nnvm::NodeEntry*, size_t>& entry_top_order_map,
+                      std::vector<nnvm::NodeEntry*>* input_entries) {
+  const auto& indexed_graph = g.indexed_graph();
+  int label = -1;
+  for (size_t i = 0; i < subgraph_nodes.size(); ++i) {
+    if (label != -1) {
+      CHECK_EQ(subgraph_nodes[i]->label, label);
+    } else {
+      label = subgraph_nodes[i]->label;
+    }
+    for (auto& e : subgraph_nodes[i]->node->inputs) {
+      nnvm::Node* source = e.node.get();
+      // a source missing from the indexed graph is treated as external
+      // (the original comment: an adjacent subgraph node)
+      const bool from_outside = !indexed_graph.exist(source)
+        || simple_nodes[indexed_graph.node_id(source)]->label != label;
+      if (from_outside) {
+        input_entries->push_back(&e);
+      }
+    }
+  }
+  SortEntries(entry_top_order_map, input_entries);
+}
+
+/*!
+ * \brief Given a subgraph, find the output entries of the subgraph.
+ * \param g pointer to the whole graph
+ * \param simple_nodes vector of simple nodes in top sorted order
+ * \param subgraph_nodes vector of pointers to the simple nodes of a subgraph.
+ * \param entry_top_order_map mapping entry pointer to its top sorted position
+ * \param output_entries output entries of the subgraph
+ */
+void FindOutputEntries(Graph* g,
+ const std::vector<SimpleNodePtr>& simple_nodes,
+ const std::vector<SimpleNode*>& subgraph_nodes,
+ const std::unordered_map<const nnvm::NodeEntry*, size_t>&
+ entry_top_order_map,
+ std::vector<nnvm::NodeEntry*>* output_entries) {
+ if (subgraph_nodes.empty()) return;  // nothing to collect for an empty subgraph
+ const auto& indexed_graph = g->indexed_graph();
+ int label = -1;  // label shared by all nodes of this subgraph, taken from the first node
+ for (size_t i = 0; i < subgraph_nodes.size(); ++i) {
+ if (label == -1) {
+ label = subgraph_nodes[i]->label;
+ } else {
+ CHECK_EQ(subgraph_nodes[i]->label, label);
+ }
+ for (auto it = subgraph_nodes[i]->outputs.begin();
+ it != subgraph_nodes[i]->outputs.end(); ++it) {
+ if (indexed_graph.exist(it->first)) {
+ // if the output node is a normal graph node (not a subgraph node)
+ const auto nid = indexed_graph.node_id(it->first);
+ // this is a node not belonging to the current subgraph
+ if (simple_nodes[nid]->label != label) {
+ // TODO(zhengda) I need to test this.
+ for (auto idx : it->second) {  // idx: positions in the consumer's inputs that read this node
+ auto& e = simple_nodes[nid]->node->inputs[idx];
+ output_entries->push_back(&e);
+ }
+ }
+ } else {
+ // if the output node is a subgraph node,
+ // the two subgraphs are adjacent
+ for (auto idx : it->second) {
+ output_entries->push_back(&(it->first->inputs[idx]));
+ }
+ }
+ }
+ }
+ // Check if the current subgraph contains a node that produces an output
+ // of the whole graph. If so, save its corresponding entry as well.
+ for (size_t i = 0; i < g->outputs.size(); ++i) {
+ auto& entry = g->outputs[i];
+ // The entry may already have been updated to be an output of
+ // a subgraph node. In this case, there is no need
+ // to check its source for the current subgraph. Otherwise,
+ // do the following.
+ if (indexed_graph.exist(entry.node.get())) {
+ const auto nid = indexed_graph.node_id(entry.node.get());
+ if (simple_nodes[nid]->label == label) {
+ output_entries->push_back(&entry);
+ }
+ }
+ }
+ SortEntries(entry_top_order_map, output_entries);
+}
+
+/*!
+ * \brief Cuts the given input entries off the graph and replaces each of them
+ *        with a new variable node that becomes an input node of the subgraph.
+ *
+ * The original entry at position i is saved in `orig_entries[i]`; the variable
+ * that replaces it is named after the entry's output name. When `skip_var` is
+ * true, entries whose source is already a variable are left untouched and
+ * their slot in `orig_entries` stays default-constructed.
+ * \param input_entries entries to cut; modified in place
+ * \param orig_entries output: copies of the entries before they were cut
+ * \param skip_var whether entries produced by variable nodes are skipped
+ */
+void CutGraphInputs(const std::vector<nnvm::NodeEntry*> &input_entries,
+                    std::vector<nnvm::NodeEntry> *orig_entries,
+                    const bool skip_var = false) {
+  orig_entries->resize(input_entries.size());
+  for (size_t i = 0; i < input_entries.size(); ++i) {
+    nnvm::NodeEntry *e = input_entries[i];
+    // If the node is a variable itself, we may want to skip the node.
+    const bool leave_untouched = e->node->is_variable() && skip_var;
+    if (!leave_untouched) {
+      (*orig_entries)[i] = *e;
+      // a one-output symbol is used only to obtain the name of the entry
+      nnvm::Symbol sym;
+      sym.outputs.push_back(*e);
+      const auto output_names = sym.ListOutputNames();
+      CHECK_EQ(output_names.size(), 1U);
+      *e = nnvm::NodeEntry{nnvm::CreateVariableNode(output_names[0]), 0, 0};
+    }
+  }
+}
+
+/*!
+ * \brief Replace a set of nodes belonging to the same subgraph with a subgraph node
+ * and keep the subgraph in the subgraph node. The input entries and output entries
+ * of the subgraph node are kept in the same order as the subgraph's.
+ */
+void CreateSubgraphNode(Graph* g,
+ const std::vector<SimpleNodePtr>& simple_nodes,
+ const std::vector<SimpleNode*>& subgraph_nodes,
+ const size_t subgraph_id,
+ std::unordered_map<const nnvm::NodeEntry*, size_t>* entry_top_order_map) {
+#if SUBGRAPH_DEBUG
+ LOG(INFO) << "Searching for input entries...";
+#endif
+ std::vector<nnvm::NodeEntry*> input_entries;
+ FindInputEntries(*g, simple_nodes, subgraph_nodes, *entry_top_order_map, &input_entries);
+ std::vector<nnvm::NodeEntry> orig_input_entries;  // copies of the entries before they are cut
+ // TODO(junwu): Confirm what value to pass to skip_var
+ CutGraphInputs(input_entries, &orig_input_entries, false);
+#if SUBGRAPH_DEBUG
+ PrintNodeEntries(input_entries);
+ LOG(INFO) << "Searching for output entries...";
+#endif
+ std::vector<nnvm::NodeEntry*> output_entries;
+ FindOutputEntries(g, simple_nodes, subgraph_nodes, *entry_top_order_map, &output_entries);
+
+ // Create a subgraph for the subgraph node
+ nnvm::Symbol sym;
+ sym.outputs.resize(output_entries.size());
+ for (size_t i = 0; i < output_entries.size(); ++i) {
+ sym.outputs[i] = *output_entries[i];
+ }
+ const SubgraphPropertyPtr& subg_prop = g->GetAttr<SubgraphPropertyPtr>("subgraph_property");
+ nnvm::NodePtr n = subg_prop->CreateSubgraphNode(sym, subgraph_id);
+
+ // Connect the external nodes to the subgraph node.
+ for (size_t i = 0; i < output_entries.size(); ++i) {
+ *output_entries[i] = nnvm::NodeEntry{n, static_cast<uint32_t>(i), 0};  // i-th output of the subgraph node
+ }
+ n->inputs = orig_input_entries;
+ const auto& indexed_graph = g->indexed_graph();
+ for (size_t i = 0; i < n->inputs.size(); ++i) {
+ auto& e = n->inputs[i];
+ // update entry_top_order_map with newly created orig_input_entries
+ auto it = entry_top_order_map->find(input_entries[i]);
+ CHECK(it != entry_top_order_map->end());
+ CHECK_EQ(entry_top_order_map->count(&e), 0U);
+ entry_top_order_map->emplace(&e, it->second);  // the new entry takes the position of the entry it replaces
+ // update input entries' source simple nodes' outputs map
+ nnvm::Node* node = e.node.get();
+ if (indexed_graph.exist(node)) {
+ const auto nid = indexed_graph.node_id(node);
+ SimpleNode* sn = simple_nodes[nid].get();
+ for (SimpleNode* dest_node : subgraph_nodes) {
+ sn->outputs.erase(dest_node->node);  // the source no longer feeds the subgraph's nodes directly
+ }
+ sn->outputs[n.get()].push_back(i);  // it now feeds input i of the subgraph node
+ }
+ }
+#if SUBGRAPH_DEBUG
+ PrintNodeEntries(output_entries);
+#endif
+}
+
+} // namespace sg
+
+/*!
+ * \brief Sort entries of all the nodes' inputs vectors in the topological order.
+ * This is going to be used to sort input/output entries of subgraphs to keep
+ * the topological order unchanged.
+ */
+void TopSortEntries(const Graph& g,
+ std::unordered_map<const nnvm::NodeEntry*, size_t>* entry_top_order_map) {
+ CHECK(entry_top_order_map != nullptr);
+ std::unordered_set<const nnvm::Node*> visited;  // nodes already pushed onto the stack once
+ // tuple: (graph node, index of node's inputs, node entry as the output of the graph node)
+ std::stack<std::tuple<nnvm::Node*, size_t, const nnvm::NodeEntry*>> s;
+ auto in_degree = [] (const nnvm::Node* node)->size_t {  // number of input entries of a node
+ if (!node) {
+ return 0;
+ }
+ CHECK_EQ(node->control_deps.size(), 0U);  // nodes with control dependencies are rejected
+ return node->inputs.size();
+ };
+ for (auto& e : g.outputs) {  // depth-first traversal from each graph output, explicit stack
+ nnvm::Node* node = e.node.get();
+ if (visited.count(node) == 0U) {
+ s.emplace(node, 0U, &e);
+ visited.insert(node);
+ }
+ while (!s.empty()) {
+ auto& top = s.top();
+ if (std::get<1>(top) == in_degree(std::get<0>(top))) {
+ // The node's inputs have been exhausted: give its entry the next order index.
+ entry_top_order_map->emplace(std::get<2>(top), entry_top_order_map->size());
+ s.pop();
+ } else {
+ // The node still has input entries not visited.
+ CHECK_LT(std::get<1>(top), std::get<0>(top)->inputs.size());
+ auto& entry = std::get<0>(top)->inputs[std::get<1>(top)++];  // take next input, advance cursor
+ nnvm::Node* input_node = entry.node.get();
+ if (visited.count(input_node) == 0U) {
+ // The entry's source node has not been visited.
+ // Push the entry to the stack for marking order later.
+ s.emplace(input_node, 0U, &entry);
+ visited.insert(input_node);
+ } else {
+ // The entry's source node has been visited before.
+ // Marking order for it.
+ entry_top_order_map->emplace(&entry, entry_top_order_map->size());
+ }
+ }
+ }
+ }
+}
+
+Graph PartitionGraph(Graph&& g) {
+ if (!g.HasAttr("subgraph_property")) { // no property attached: nothing to partition
+ LOG(INFO) << "The graph has no attribute of subgraph_property attached. "
+ "The original graph is returned.";
+ return g;
+ }
+ using namespace sg;
+ const SubgraphPropertyPtr& subg_prop = g.GetAttr<SubgraphPropertyPtr>("subgraph_property");
+ // top sort NodeEntry of all the nodes' inputs
+ std::unordered_map<const nnvm::NodeEntry*, size_t> entry_top_order_map;
+ TopSortEntries(g, &entry_top_order_map);
+
+ // Create undirected graph for ease of finding subgraphs
+ std::vector<SimpleNodePtr> simple_nodes;
+ CreateSimpleGraph(g, &simple_nodes);
+ std::vector<std::vector<SimpleNode*>> subgraph_nodes;  // one vector of simple nodes per subgraph
+ FindSubgraphs(&g, *subg_prop, simple_nodes, &subgraph_nodes);
+ for (size_t i = 0; i < subgraph_nodes.size(); ++i) {
+#if SUBGRAPH_DEBUG
+ std::set<SimpleNode*> simple_node_set(subgraph_nodes[i].begin(), subgraph_nodes[i].end());
+ CHECK_EQ(simple_node_set.size(), subgraph_nodes[i].size());
+ PrintSubgraph(subgraph_nodes[i]);
+#endif
+ CreateSubgraphNode(&g, simple_nodes, subgraph_nodes[i], i, &entry_top_order_map);  // i is the id
+ }
+ return g;  // NOTE(review): g is a named rvalue reference, copied here before C++20; consider std::move
+}
+
+NNVM_REGISTER_PASS(PartitionGraph)
+.describe("Partition a graph according to the user defined rules "
+ "in a derived class of SubgraphProperty")
+.set_body(PartitionGraph)  // the pass body is the PartitionGraph function defined in this file
+.set_change_graph(true);  // presumably flags that the pass modifies the graph -- confirm in nnvm docs
+
+} // namespace op
+} // namespace mxnet
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index ed4aaa4..9f20627 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -44,6 +44,7 @@ from test_gluon_rnn import *
from test_sparse_ndarray import *
from test_sparse_operator import *
from test_ndarray import *
+from test_subgraph_op import *
set_default_context(mx.gpu(0))
del test_support_vector_machine_l1_svm
diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py
index 1c59cea..f8833f8 100644
--- a/tests/python/unittest/test_gluon_trainer.py
+++ b/tests/python/unittest/test_gluon_trainer.py
@@ -175,6 +175,7 @@ def test_trainer_save_load():
# check if parameter dict is correctly associated with optimizer after load_state
assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.2
+@unittest.skip("temporarily disabled till it gets fixed. tracked at https://github.com/apache/incubator-mxnet/issues/11353")
@with_seed()
def test_trainer_reset_kv():
def check_trainer_reset_kv(kv):
diff --git a/tests/python/unittest/test_subgraph_op.py b/tests/python/unittest/test_subgraph_op.py
new file mode 100644
index 0000000..f08c42c
--- /dev/null
+++ b/tests/python/unittest/test_subgraph_op.py
@@ -0,0 +1,135 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import ctypes
+import mxnet as mx
+from mxnet.base import SymbolHandle, check_call, _LIB, mx_uint, c_str_array
+from mxnet.symbol import Symbol
+import numpy as np
+
+
+def test_subgraph():  # partitioned symbol must match the original in inputs, shapes, types, outputs
+ def get_graph():  # returns (symbol, op names to partition on) pairs over one sin -> Convolution graph
+ data1 = mx.sym.Variable('data1', shape=(3, 3, 10, 10), dtype=np.float32)
+ data2 = mx.sym.Variable('data2', shape=(1, 0, 2, 2))  # 0 presumably marks a dim left to shape inference
+ data3 = mx.sym.sin(data2)
+ conv = mx.sym.Convolution(data=data1, weight=data3, kernel=(2, 2), num_filter=1)
+ rets = []
+ rets.append((conv, []))  # no op selected
+ rets.append((conv, [mx.sym.sin.__name__]))
+ rets.append((conv, [mx.sym.Convolution.__name__]))
+ rets.append((conv, [mx.sym.sin.__name__, mx.sym.Convolution.__name__]))
+ return rets
+
+ for regular_sym, op_names in get_graph():
+ input_names = regular_sym.list_inputs()
+ shapes = regular_sym.infer_shape()
+ types = regular_sym.infer_type()
+ out = SymbolHandle()  # receives the handle of the partitioned symbol
+
+ check_call(_LIB.MXPartitionGraph(regular_sym.handle, mx_uint(len(op_names)),
+ c_str_array(op_names), ctypes.byref(out)))
+ subgraph_sym = Symbol(out)
+ assert input_names == subgraph_sym.list_inputs()  # same input names in the same order
+
+ print(subgraph_sym.list_outputs())
+ assert shapes == subgraph_sym.infer_shape()
+ assert types == subgraph_sym.infer_type()
+
+ regular_exec = regular_sym.simple_bind(ctx=mx.cpu(), grad_req='null')
+ subgraph_exec = subgraph_sym.simple_bind(ctx=mx.cpu(), grad_req='null')
+
+ for name in input_names:  # feed both executors the same random inputs
+ regular_exec.arg_dict[name][:] = mx.nd.random.normal(
+ shape=regular_exec.arg_dict[name].shape)
+ subgraph_exec.arg_dict[name][:] = regular_exec.arg_dict[name]
+
+ subgraph_exec.forward()
+ regular_exec.forward()
+ mx.nd.waitall()
+ assert (subgraph_exec.outputs[0] - regular_exec.outputs[0]).abs().sum().asscalar() == 0.0  # exact match
+
+
+def test_input_name_order():  # partitioning must keep inputs, arguments and aux states in order
+ def check_input_order(sym, op_names):  # partition sym on op_names and compare the name lists
+ out = SymbolHandle()
+ check_call(_LIB.MXPartitionGraph(sym.handle, mx_uint(len(op_names)),
+ c_str_array(op_names), ctypes.byref(out)))
+
+ new_sym = Symbol(out)
+ #print(sym.list_inputs())
+ #print(new_sym.list_inputs())
+ assert new_sym.list_inputs() == sym.list_inputs()
+ assert new_sym.list_arguments() == sym.list_arguments()
+ assert new_sym.list_auxiliary_states() == sym.list_auxiliary_states()
+ #print(new_sym.list_arguments())
+ #print(new_sym.list_auxiliary_states())
+ #print('original outputs: %s' % sym.list_outputs())
+ #print('new sym outputs: %s' % new_sym.list_outputs())
+
+ def test_network_structure_1():  # two convolutions sharing both variables, grouped as outputs
+ data1 = mx.sym.var('data1')
+ data2 = mx.sym.var('data2')
+ conv1 = mx.sym.Convolution(data=data1, weight=data2, no_bias=True, kernel=(2, 2), num_filter=1)
+ conv2 = mx.sym.Convolution(data=data2, weight=data1, no_bias=True, kernel=(2, 2), num_filter=1)
+ out = mx.sym.Group([conv1, conv2])
+ check_input_order(out, ['Convolution'])
+
+ def test_network_structure_2():  # same convolutions, but summed into a single output
+ data1 = mx.sym.var('data1')
+ data2 = mx.sym.var('data2')
+ conv1 = mx.sym.Convolution(data=data1, weight=data2, no_bias=True, kernel=(2, 2), num_filter=1)
+ conv2 = mx.sym.Convolution(data=data2, weight=data1, no_bias=True, kernel=(2, 2), num_filter=1)
+ out = conv1 + conv2
+ check_input_order(out, ['Convolution'])
+ check_input_order(out, ['Convolution', '_Plus', 'elemwise_add', '_plus'])
+
+ def test_network_structure_3():
+ # this tests whether the partitioning algorithm can deal with cycles
+ data = mx.sym.var('data')
+ ret = mx.sym.exp(data)
+ ret1 = mx.sym.cos(ret)
+ ret2 = mx.sym.sin(ret)
+ ret = ret1 + ret2  # exp feeds both branches, which rejoin at the add
+ check_input_order(ret, ['exp', 'sin', '_Plus', 'elemwise_add', '_plus'])
+ check_input_order(ret, ['exp', 'cos', '_Plus', 'elemwise_add', '_plus'])
+
+ def test_network_structure_4():
+ # this tests whether the partitioned sym can distinguish in_args and aux_states
+ data = mx.sym.var('data')
+ ret = mx.sym.exp(data)
+ ret1 = mx.sym.cos(ret)
+ ret2 = mx.sym.sin(ret)
+ ret = ret1 + ret2
+ ret = mx.sym.BatchNorm(ret)
+ ret = mx.sym.BatchNorm(ret)
+ check_input_order(ret, ['exp', 'sin', '_Plus', 'elemwise_add', '_plus'])
+ check_input_order(ret, ['exp', 'cos', '_Plus', 'elemwise_add', '_plus'])
+ check_input_order(ret, ['exp', 'sin', '_Plus', 'elemwise_add', '_plus', 'BatchNorm'])
+ check_input_order(ret, ['exp', 'cos', '_Plus', 'elemwise_add', '_plus', 'BatchNorm'])
+ check_input_order(ret, ['exp', 'BatchNorm'])
+ check_input_order(ret, ['BatchNorm'])
+
+ test_network_structure_1()  # the nested cases are invoked explicitly
+ test_network_structure_2()
+ test_network_structure_3()
+ test_network_structure_4()
+
+
+if __name__ == '__main__':  # when executed as a script, run this module's tests through nose
+ import nose
+ nose.runmodule()