You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mo...@apache.org on 2022/01/05 15:53:13 UTC
[tvm] branch main updated: [microNPU][2a] Add CascaderGraph for cascading analysis (#9469)
This is an automated email from the ASF dual-hosted git repository.
mousius pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 72d3efe [microNPU][2a] Add CascaderGraph for cascading analysis (#9469)
72d3efe is described below
commit 72d3efe1575a4ee5f4287edb35015ecb7d25c4b0
Author: Matthew Barrett <55...@users.noreply.github.com>
AuthorDate: Wed Jan 5 15:52:42 2022 +0000
[microNPU][2a] Add CascaderGraph for cascading analysis (#9469)
A CascaderGraph augments a TE graph with additional
information needed by the cascading algorithms. This
includes defining a strict ordering on the operators
as well as including all the Propagators needed to
do the affine analysis of cascades.
The CascaderGraph consists of two object types, Parts
and Tensors. A Part is an augmented operator which
includes the Propagators and a Tensor is similar to a
TE tensor but stores additional information like
compression ratio.
---
cmake/modules/contrib/EthosU.cmake | 7 +-
python/tvm/contrib/ethosu/cascader/__init__.py | 4 +-
python/tvm/contrib/ethosu/cascader/graph.py | 170 +++++++++++
.../ethosu/cascader/{__init__.py => parts.py} | 27 +-
src/contrib/ethosu/cascader/common.h | 25 +-
src/contrib/ethosu/cascader/graph.cc | 257 +++++++++++++++++
src/contrib/ethosu/cascader/graph.h | 321 +++++++++++++++++++++
src/contrib/ethosu/cascader/parts/inline.cc | 66 +++++
src/contrib/ethosu/cascader/parts/inline.h | 80 +++++
.../contrib/test_ethosu/cascader/__init__.py | 2 +-
.../contrib/test_ethosu/cascader/test_graph.py | 134 +++++++++
11 files changed, 1083 insertions(+), 10 deletions(-)
diff --git a/cmake/modules/contrib/EthosU.cmake b/cmake/modules/contrib/EthosU.cmake
index 7af7b48..0edeae3 100644
--- a/cmake/modules/contrib/EthosU.cmake
+++ b/cmake/modules/contrib/EthosU.cmake
@@ -18,12 +18,15 @@
if(USE_ETHOSU)
tvm_file_glob(GLOB COMPILER_ETHOSU_SRCS
src/relay/backend/contrib/ethosu/*
- src/contrib/ethosu/cascader/*)
+ src/contrib/ethosu/cascader/*
+ src/contrib/ethosu/cascader/parts/*)
list(APPEND COMPILER_SRCS ${COMPILER_ETHOSU_SRCS})
else()
# Keeping just utils.cc because it has Object definitions
# used by python side
tvm_file_glob(GLOB COMPILER_ETHOSU_SRCS
- src/relay/backend/contrib/ethosu/utils.cc)
+ src/relay/backend/contrib/ethosu/utils.cc
+ src/contrib/ethosu/cascader/*
+ src/contrib/ethosu/cascader/parts/*)
list(APPEND COMPILER_SRCS ${COMPILER_ETHOSU_SRCS})
endif(USE_ETHOSU)
diff --git a/python/tvm/contrib/ethosu/cascader/__init__.py b/python/tvm/contrib/ethosu/cascader/__init__.py
index 0093592..bf06d00 100644
--- a/python/tvm/contrib/ethosu/cascader/__init__.py
+++ b/python/tvm/contrib/ethosu/cascader/__init__.py
@@ -14,10 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""The NPU cascading planner.
+"""The NPU cascader.
This component performs inter-operator scheduling to optimize
for both performance and memory usage on Arm(R) Ethos(TM)-U NPUs.
"""
from .stripe_config import StripeConfig
from .propagator import Propagator
+from .graph import PerformanceInfo, Tensor, Part, TESubgraph, CascaderGraph
+from .parts import InlinePart
diff --git a/python/tvm/contrib/ethosu/cascader/graph.py b/python/tvm/contrib/ethosu/cascader/graph.py
new file mode 100644
index 0000000..001bbbf
--- /dev/null
+++ b/python/tvm/contrib/ethosu/cascader/graph.py
@@ -0,0 +1,170 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Graph objects to define compute graphs for the NPU cascader."""
+from typing import List
+from collections import namedtuple
+import tvm._ffi
+
+from tvm.runtime import Object
+
+from .stripe_config import StripeConfig
+from . import _ffi_api
+
+
+TESubgraph = namedtuple("TESubgraph", ["input_tensors", "output_tensor"])
+
+
+@tvm._ffi.register_object("contrib.ethosu.cascader.PerformanceInfo")
+class PerformanceInfo(Object):
+ """PerformanceInfo class"""
+
+ @property
+ def compute_cycles(self):
+ return self._compute_cycles
+
+ @property
+ def read_bytes(self):
+ return list(self._read_bytes)
+
+ @property
+ def write_bytes(self):
+ return self._write_bytes
+
+
+@tvm._ffi.register_object("contrib.ethosu.cascader.Tensor")
+class Tensor(Object):
+ """Tensor class"""
+
+ def __init__(self, shape, dtype, is_constant=False, compression_ratio=1):
+ self.__init_handle_by_constructor__(
+ _ffi_api.Tensor, shape, dtype, is_constant, compression_ratio
+ )
+
+ def add_producer(self, part):
+ _ffi_api.TensorAddProducer(self, part)
+
+ def add_consumer(self, part):
+ _ffi_api.TensorAddConsumer(self, part)
+
+ @property
+ def producers(self):
+ return list(self._producers)
+
+ @property
+ def consumers(self):
+ return list(self._consumers)
+
+ @property
+ def shape(self):
+ return list(self._shape)
+
+ @property
+ def dtype(self):
+ return self._dtype
+
+ @property
+ def is_constant(self):
+ return self._is_constant
+
+ @property
+ def compression_ratio(self):
+ return self._compression_ratio
+
+ @property
+ def size(self):
+ return self._size
+
+
+class Part(Object):
+ """Part base class"""
+
+ def set_input(self, index: int, tensor: Tensor):
+ _ffi_api.PartSetInput(self, index, tensor)
+
+ def set_output(self, tensor: Tensor):
+ _ffi_api.PartSetOutput(self, tensor)
+
+ def calculate_input_stripe_configs(
+ self, output_stripe_config: StripeConfig
+ ) -> List[StripeConfig]:
+ return list(_ffi_api.PartCalculateInputStripeConfigs(self, output_stripe_config))
+
+ def get_stripe_align_hint(self) -> List[int]:
+ return list(_ffi_api.PartGetStripeAlignHint(self))
+
+ def get_performance_info(
+ self, stripe_config: StripeConfig, is_rolling: bool
+ ) -> PerformanceInfo:
+ return _ffi_api.PartGetPerformanceInfo(self, stripe_config, is_rolling)
+
+ @property
+ def input_tensors(self):
+ return list(self._input_tensors)
+
+ @property
+ def output_tensor(self):
+ return self._output_tensor
+
+ @property
+ def propagators(self):
+ return list(self._propagators)
+
+ @property
+ def in_line(self):
+ return self._in_line
+
+ @property
+ def subgraph(self):
+ return TESubgraph(list(self._te_input_tensors), self._te_output_tensor)
+
+
+@tvm._ffi.register_object("contrib.ethosu.cascader.CascaderGraph")
+class CascaderGraph(Object):
+ """A class to describe a graph of Parts and Tensors used by the cascader.
+
+ This class describes a graph consisting of two object types: Tensors and Parts.
+ It defines a topological ordering on the graph such that each Part and Tensor has a
+ position in the ordering. This ordering is used by the Plan and Proposal generation
+ algorithms. It is also the ordering the Parts are expected to be executed in.
+
+ In addition to defining an ordering, the Parts and Tensors are also all given unique
+ IDs which they can be referred to by."""
+
+ def __init__(self, input_tensors: List[Tensor], output_tensors: List[Tensor]):
+ self.__init_handle_by_constructor__(_ffi_api.CascaderGraph, input_tensors, output_tensors)
+
+ def get_part_id(self, part: Part) -> int:
+ return _ffi_api.CascaderGraphGetPartID(self, part)
+
+ def get_tensor_id(self, tensor: Tensor) -> int:
+ return _ffi_api.CascaderGraphGetTensorID(self, tensor)
+
+ @property
+ def input_tensors(self):
+ return list(self._input_tensors)
+
+ @property
+ def output_tensors(self):
+ return list(self._output_tensors)
+
+ @property
+ def tensor_order(self):
+ return list(self._tensor_order)
+
+ @property
+ def part_order(self):
+ return list(self._part_order)
diff --git a/python/tvm/contrib/ethosu/cascader/__init__.py b/python/tvm/contrib/ethosu/cascader/parts.py
similarity index 59%
copy from python/tvm/contrib/ethosu/cascader/__init__.py
copy to python/tvm/contrib/ethosu/cascader/parts.py
index 0093592..48d2d77 100644
--- a/python/tvm/contrib/ethosu/cascader/__init__.py
+++ b/python/tvm/contrib/ethosu/cascader/parts.py
@@ -14,10 +14,27 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""The NPU cascading planner.
+"""Parts used by the NPU cascader."""
+from typing import List
+import tvm._ffi
-This component performs inter-operator scheduling to optimize
-for both performance and memory usage on Arm(R) Ethos(TM)-U NPUs.
-"""
-from .stripe_config import StripeConfig
from .propagator import Propagator
+from .graph import Part, TESubgraph
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("contrib.ethosu.cascader.InlinePart")
+class InlinePart(Part):
+ """InlinePart class"""
+
+ def __init__(
+ self,
+ te_subgraph: TESubgraph,
+ propagators: List[Propagator],
+ ):
+ self.__init_handle_by_constructor__(
+ _ffi_api.InlinePart,
+ te_subgraph.input_tensors,
+ te_subgraph.output_tensor,
+ propagators,
+ )
diff --git a/src/contrib/ethosu/cascader/common.h b/src/contrib/ethosu/cascader/common.h
index 07e1a36..ec62861 100644
--- a/src/contrib/ethosu/cascader/common.h
+++ b/src/contrib/ethosu/cascader/common.h
@@ -27,6 +27,8 @@
#include <tvm/ir/expr.h>
#include <tvm/runtime/container/array.h>
+#include <functional>
+#include <numeric>
#include <vector>
namespace tvm {
@@ -51,6 +53,22 @@ inline Array<Integer> make_array(const std::vector<int>& vec) {
}
/*!
+ * \brief Make a tvm::Array<Integer> from a size_t vector.
+ * \param vec The size_t vector.
+ * \return The Integer Array.
+ * \note Array<Integer>(std::vector<size_t>) doesn't work as this implicit
+ * type conversion fails. This is why this helper is required.
+ */
+inline Array<Integer> make_array(const std::vector<size_t>& vec) {
+ Array<Integer> arr;
+ arr.resize(vec.size());
+ for (unsigned int i = 0; i < vec.size(); ++i) {
+ arr.Set(i, Integer(vec[i]));
+ }
+ return arr;
+}
+
+/*!
* \brief Make a tvm::Array<FloatImm> from a float vector.
* \param vec The float vector.
* \return The FloatImm Array.
@@ -69,7 +87,7 @@ inline Array<FloatImm> make_array(const std::vector<float>& vec) {
* \param arr The Array.
* \return The vector.
*/
-template <class T, class tvm_T>
+template <typename T, typename tvm_T>
inline std::vector<T> make_vector(const Array<tvm_T>& arr) {
std::vector<T> vec(arr.size());
for (unsigned int i = 0; i < arr.size(); ++i) {
@@ -103,6 +121,11 @@ inline std::size_t hash_vector(const std::vector<T>& vec) {
return seed;
}
+/*!
+ * \brief Multiply together all the elements of a vector.
+ * \param vec The vector whose elements are multiplied.
+ * \return The product of the elements, or 1 if the vector is empty.
+ * \note The accumulator is seeded with T(1) rather than the int literal 1 so that
+ * the running product is held in T and is not narrowed to int on every step.
+ */
+template <class T>
+inline T mul_reduce(const std::vector<T>& vec) {
+ return std::accumulate(vec.begin(), vec.end(), static_cast<T>(1), std::multiplies<T>());
+}
+
} // namespace cascader
} // namespace ethosu
} // namespace contrib
diff --git a/src/contrib/ethosu/cascader/graph.cc b/src/contrib/ethosu/cascader/graph.cc
new file mode 100644
index 0000000..a930c26
--- /dev/null
+++ b/src/contrib/ethosu/cascader/graph.cc
@@ -0,0 +1,257 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "graph.h"
+
+#include <tvm/node/reflection.h>
+#include <tvm/runtime/container/array.h>
+#include <tvm/runtime/object.h>
+#include <tvm/runtime/registry.h>
+
+#include <algorithm>
+#include <stack>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "common.h"
+#include "stripe_config.h"
+
+namespace tvm {
+namespace contrib {
+namespace ethosu {
+namespace cascader {
+
+void PerformanceInfoNode::VisitAttrs(AttrVisitor* v) {
+ int compute_cycles_int = static_cast<int>(compute_cycles);
+ v->Visit("_compute_cycles", &compute_cycles_int);
+ Array<Integer> tmp_reads = make_array(read_bytes);
+ v->Visit("_read_bytes", &tmp_reads);
+ int write_bytes_int = static_cast<int>(write_bytes);
+ v->Visit("_write_bytes", &write_bytes_int);
+}
+
+TVM_REGISTER_NODE_TYPE(PerformanceInfoNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<PerformanceInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const PerformanceInfoNode*>(ref.get());
+ p->stream << "PerformanceInfo(compute_cycles=" << node->compute_cycles << ", read_bytes=[";
+ for (auto rb : node->read_bytes) {
+ p->stream << rb << ", ";
+ }
+ p->stream << "], write_bytes=" << node->write_bytes << ")";
+ });
+
+void TensorNode::VisitAttrs(AttrVisitor* v) {
+ Array<Integer> tmp_arr = make_array(shape_);
+ v->Visit("_shape", &tmp_arr);
+ v->Visit("_dtype", &dtype_);
+ v->Visit("_is_constant", &is_constant_);
+ double compression_ratio = static_cast<double>(compression_ratio_);
+ v->Visit("_compression_ratio", &compression_ratio);
+ Array<Part> tmp_prods(producers_);
+ v->Visit("_producers", &tmp_prods);
+ Array<Part> tmp_cons(consumers_);
+ v->Visit("_consumers", &tmp_cons);
+ v->Visit("_size", &size_);
+}
+
+Tensor::Tensor(const std::vector<int>& shape, DataType dtype, bool is_constant = false,
+ float compression_ratio = 1.0) {
+ auto n = make_object<TensorNode>();
+ n->shape_ = std::move(shape);
+ n->dtype_ = dtype;
+ n->is_constant_ = is_constant;
+ n->compression_ratio_ = compression_ratio;
+ n->size_ = mul_reduce(n->shape_) * n->dtype_.bytes();
+ data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.Tensor")
+ .set_body_typed([](Array<Integer> shape, DataType dtype, bool is_constant,
+ double compression_ratio) {
+ std::vector<int> vshape = make_vector<int, Integer>(shape);
+ return Tensor(vshape, dtype, is_constant, compression_ratio);
+ });
+
+TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.TensorAddProducer")
+ .set_body_method<Tensor>(&TensorNode::AddProducer);
+TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.TensorAddConsumer")
+ .set_body_method<Tensor>(&TensorNode::AddConsumer);
+
+TVM_REGISTER_NODE_TYPE(TensorNode);
+
+void PartNode::VisitAttrs(AttrVisitor* v) {
+ Array<Propagator> tmp_prp(propagators_);
+ v->Visit("_propagators", &tmp_prp);
+ Array<Tensor> tmp_ins(input_tensors_);
+ v->Visit("_input_tensors", &tmp_ins);
+ v->Visit("_output_tensor", &output_tensor_);
+ v->Visit("_in_line", &in_line_);
+ Array<te::Tensor> tmp_te_ins(subgraph_.input_tensors);
+ v->Visit("_te_input_tensors", &tmp_te_ins);
+ v->Visit("_te_output_tensor", &subgraph_.output_tensor);
+}
+
+void PartNode::SetInput(uint64_t input_index, const Tensor& input_tensor) {
+ ICHECK_LT(input_index, input_tensors_.size());
+ input_tensors_[input_index] = std::move(input_tensor);
+}
+
+std::vector<StripeConfig> PartNode::CalculateInputStripeConfigs(
+ const StripeConfig& output_stripe_config) {
+ std::vector<StripeConfig> input_stripe_configs;
+ for (const auto& propagator : propagators_) {
+ input_stripe_configs.push_back(propagator->propagate(output_stripe_config));
+ }
+ return input_stripe_configs;
+}
+
+const std::vector<int> PartNode::GetStripeAlignHint() const {
+ ICHECK_GT(propagators_.size(), 0);
+ size_t dims = propagators_[0]->GetOutputDims();
+ std::vector<int> compute_quantum(dims);
+ for (size_t i = 0; i < dims; i++) {
+ compute_quantum[i] = 1;
+ }
+ return compute_quantum;
+}
+
+TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.PartSetInput")
+ .set_body_method<Part>(&PartNode::SetInput);
+TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.PartSetOutput")
+ .set_body_method<Part>(&PartNode::SetOutput);
+TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.PartCalculateInputStripeConfigs")
+ .set_body_typed([](Part part, StripeConfig output_stripe_config) {
+ auto input_stripe_configs = part->CalculateInputStripeConfigs(output_stripe_config);
+ return Array<StripeConfig>(input_stripe_configs);
+ });
+TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.PartGetStripeAlignHint").set_body_typed([](Part part) {
+ std::vector<int> align_hint = part->GetStripeAlignHint();
+ return make_array(align_hint);
+});
+TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.PartGetPerformanceInfo")
+ .set_body_typed([](Part part, StripeConfig stripe_config, bool is_rolling) {
+ return part->GetPerformanceInfo(stripe_config, is_rolling);
+ });
+
+CascaderGraphNode::CascaderGraphNode(std::vector<Tensor> input_tensors,
+ std::vector<Tensor> output_tensors)
+ : input_tensors_(input_tensors), output_tensors_(output_tensors) {
+ Init_();
+}
+
+bool VisitedInputs(
+ const Part& part,
+ const std::unordered_set<Tensor, ObjectPtrHash, ObjectPtrEqual>& visited_tensors) {
+ for (const auto& input_tensor : part->GetInputTensors()) {
+ if (visited_tensors.find(input_tensor) == visited_tensors.end()) {
+ return false;
+ }
+ }
+ return true;
+}
+
+void CascaderGraphNode::Init_() {
+ std::stack<Tensor> stack;
+ std::unordered_set<Tensor, ObjectPtrHash, ObjectPtrEqual> visited_tensors;
+ std::unordered_set<Part, ObjectPtrHash, ObjectPtrEqual> visited_parts;
+ for (const auto& input : input_tensors_) {
+ stack.push(input);
+ }
+ // Visit the Parts/Tensors in depth-first order using a non-recursive algorithm
+ while (!stack.empty()) {
+ Tensor tensor = stack.top();
+ stack.pop();
+ if (visited_tensors.find(tensor) == visited_tensors.end()) {
+ visited_tensors.insert(tensor);
+ tensor_order_.push_back(tensor);
+ for (const auto& part : tensor->GetConsumers()) {
+ if (visited_parts.find(part) == visited_parts.end()) {
+ // Only visit a Part once we've visited all its input Tensors
+ if (!VisitedInputs(part, visited_tensors)) continue;
+ visited_parts.insert(part);
+ part_order_.push_back(part);
+ stack.push(part->GetOutputTensor());
+ }
+ }
+ }
+ }
+ std::reverse(tensor_order_.begin(), tensor_order_.end());
+ std::reverse(part_order_.begin(), part_order_.end());
+ int id = 0;
+ for (const auto& part : part_order_) {
+ part_id_map_[part] = id;
+ id++;
+ }
+ id = 0;
+ for (const auto& tensor : tensor_order_) {
+ tensor_id_map_[tensor] = id;
+ id++;
+ }
+}
+
+void CascaderGraphNode::VisitAttrs(AttrVisitor* v) {
+ Array<Tensor> tmp_ins(input_tensors_);
+ v->Visit("_input_tensors", &tmp_ins);
+ Array<Tensor> tmp_outs(output_tensors_);
+ v->Visit("_output_tensors", &tmp_outs);
+ Array<Part> tmp_parr(part_order_);
+ v->Visit("_part_order", &tmp_parr);
+ Array<Tensor> tmp_tarr(tensor_order_);
+ v->Visit("_tensor_order", &tmp_tarr);
+}
+
+int CascaderGraphNode::GetPartID(const Part& part) const {
+ if (part_id_map_.find(part) == part_id_map_.end()) {
+ return -1;
+ }
+ return part_id_map_.at(part);
+}
+
+int CascaderGraphNode::GetTensorID(const Tensor& tensor) const {
+ if (tensor_id_map_.find(tensor) == tensor_id_map_.end()) {
+ return -1;
+ }
+ return tensor_id_map_.at(tensor);
+}
+
+CascaderGraph::CascaderGraph(std::vector<Tensor> input_tensors,
+ std::vector<Tensor> output_tensors) {
+ auto n = make_object<CascaderGraphNode>(input_tensors, output_tensors);
+ data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.CascaderGraph")
+ .set_body_typed([](Array<Tensor> input_tensors, Array<Tensor> output_tensors) {
+ std::vector<Tensor> vinput_tensors(input_tensors.begin(), input_tensors.end());
+ std::vector<Tensor> voutput_tensors(output_tensors.begin(), output_tensors.end());
+ return CascaderGraph(vinput_tensors, voutput_tensors);
+ });
+TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.CascaderGraphGetPartID")
+ .set_body_method<CascaderGraph>(&CascaderGraphNode::GetPartID);
+TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.CascaderGraphGetTensorID")
+ .set_body_method<CascaderGraph>(&CascaderGraphNode::GetTensorID);
+
+TVM_REGISTER_NODE_TYPE(CascaderGraphNode);
+
+} // namespace cascader
+} // namespace ethosu
+} // namespace contrib
+} // namespace tvm
diff --git a/src/contrib/ethosu/cascader/graph.h b/src/contrib/ethosu/cascader/graph.h
new file mode 100644
index 0000000..2bea890
--- /dev/null
+++ b/src/contrib/ethosu/cascader/graph.h
@@ -0,0 +1,321 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/contrib/ethosu/cascader/graph.h
+ * \brief Graph objects (Tensor and Part) for the Ethos-U cascader
+ */
+#ifndef TVM_CONTRIB_ETHOSU_CASCADER_GRAPH_H_
+#define TVM_CONTRIB_ETHOSU_CASCADER_GRAPH_H_
+
+#include <tvm/runtime/data_type.h>
+#include <tvm/runtime/object.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/tensor.h>
+
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "propagator.h"
+
+namespace tvm {
+namespace contrib {
+namespace ethosu {
+namespace cascader {
+
+class Tensor;
+class Part;
+class StripeConfig;
+
+/*! \brief A struct to hold a Tensor Expression subgraph */
+struct TESubgraph {
+ /*! \brief The input te::Tensors to the subgraph */
+ std::vector<te::Tensor> input_tensors;
+ /*! \brief The output te::Tensor of the subgraph */
+ te::Tensor output_tensor;
+};
+
+/*! \brief Node to hold performance information for a Part */
+class PerformanceInfoNode : public Object {
+ public:
+ void VisitAttrs(AttrVisitor* v);
+
+ /*! \brief The cycles to compute a block */
+ size_t compute_cycles;
+ /*! \brief The number of bytes read per input tensor */
+ std::vector<size_t> read_bytes;
+ /*! \brief The number of bytes written to the output tensor */
+ size_t write_bytes;
+
+ static constexpr const char* _type_key = "contrib.ethosu.cascader.PerformanceInfo";
+ TVM_DECLARE_FINAL_OBJECT_INFO(PerformanceInfoNode, Object);
+};
+
+/*!
+ * \brief A class to hold the performance information for a Part.
+ * \note The performance information for a Part is composed of 3 factors: the compute cycles,
+ * the number of bytes read from each input tensor and the number of bytes written to the output
+ * tensor. Bytes read/written are reported rather than read/write bandwidth cycles so that the
+ * calculation of the performance information can be re-used with different memory homing.
+ */
+class PerformanceInfo : public ObjectRef {
+ public:
+ PerformanceInfo(size_t compute_cycles, std::vector<size_t> read_bytes, size_t write_bytes) {
+ auto n = make_object<PerformanceInfoNode>();
+ n->compute_cycles = compute_cycles;
+ n->read_bytes = std::move(read_bytes);
+ n->write_bytes = write_bytes;
+ data_ = std::move(n);
+ }
+
+ TVM_DEFINE_OBJECT_REF_METHODS(PerformanceInfo, ObjectRef, PerformanceInfoNode);
+};
+
+/*! \brief Node to represent a Tensor */
+class TensorNode : public Object {
+ public:
+ void VisitAttrs(AttrVisitor* v);
+
+ /*! \return The shape of the tensor */
+ std::vector<int> GetShape() const { return shape_; }
+ /*! \return The data type of the tensor */
+ DataType GetDataType() const { return dtype_; }
+ /*! \return Whether the tensor stores a constant value */
+ bool IsConstant() const { return is_constant_; }
+ /*! \return The compression ratio of the tensor */
+ float GetCompressionRatio() const { return compression_ratio_; }
+ /*! \return The producers of the tensor */
+ const std::vector<Part> GetProducers() const { return producers_; }
+ /*! \return The consumers of the tensor */
+ const std::vector<Part> GetConsumers() const { return consumers_; }
+ /*! \return The size of the tensor in bytes */
+ int GetSize() const { return size_ * compression_ratio_; }
+
+ /*! \brief Add a producer of the tensor */
+ inline void AddProducer(const Part& part) { producers_.push_back(part); }
+ /*! \brief Add a consumer of the tensor */
+ inline void AddConsumer(const Part& part) { consumers_.push_back(part); }
+
+ static constexpr const char* _type_key = "contrib.ethosu.cascader.Tensor";
+ TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, Object);
+
+ protected:
+ friend class Tensor;
+
+ /*! \brief The shape of the tensor */
+ std::vector<int> shape_;
+ /*! \brief The data type of the tensor */
+ DataType dtype_;
+ /*! \brief Whether the tensor stores a constant value */
+ bool is_constant_;
+ /*! \brief The compression ratio of the tensor */
+ float compression_ratio_;
+ /*! \brief The producers of the tensor */
+ std::vector<Part> producers_;
+ /*! \brief The consumers of the tensor */
+ std::vector<Part> consumers_;
+ /*! \brief The size of the tensor in bytes */
+ int size_;
+};
+
+/*!
+ * \brief A class to describe a Tensor in a Cascader graph.
+ * \note Cascader graphs consist of two object types: Tensors and Parts. This class
+ * defines the Tensors which represent the tensors that are consumed and produced
+ * as part of the graph. They are augmented with information about their 'kind'
+ * (input/output/constant/intermediate), their default memory home (which memory they
+ * are expected to be allocated in) and a compression ratio where applicable (weights
+ * for instance are compressed).
+ */
+class Tensor : public ObjectRef {
+ public:
+ Tensor(const std::vector<int>& shape, DataType dtype, bool is_constant, float compression_ratio);
+
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Tensor, ObjectRef, TensorNode);
+};
+
+/*! \brief Node to represent a Part */
+class PartNode : public Object {
+ public:
+ virtual void VisitAttrs(AttrVisitor* v);
+
+ /*! \return The TE subgraph represented by the Part */
+ const TESubgraph GetSubgraph() const { return subgraph_; }
+ /*! \return The output->input propagators */
+ const std::vector<Propagator> GetPropagators() const { return propagators_; }
+ /*! \return Whether the Part is inline */
+ bool IsInline() const { return in_line_; }
+ /*! \return The input tensors */
+ const std::vector<Tensor> GetInputTensors() const { return input_tensors_; }
+ /*! \return The output tensor */
+ const Tensor GetOutputTensor() const { return output_tensor_; }
+
+ /*! \brief Set the input Tensor at the given input index */
+ void SetInput(uint64_t input_index, const Tensor& input_tensor);
+ /*! \brief Set the output Tensor of the Part */
+ void SetOutput(const Tensor& output_tensor) { output_tensor_ = output_tensor; }
+ /*!
+ * \brief Calculate the input stripe configs for a given output stripe config using the
+ * Propagators.
+ * \param output_stripe_config The output stripe config to propagate.
+ * \return The calculated input stripe configs.
+ */
+ std::vector<StripeConfig> CalculateInputStripeConfigs(const StripeConfig& output_stripe_config);
+ /*!
+ * \brief Get the preferred alignment in each axis for a stripe of the Part.
+ * \note This is used to bias the selection of StripeConfigs towards those that are integer
+ * multiples of a tensor intrinsic used to compute the Part.
+ */
+ virtual const std::vector<int> GetStripeAlignHint() const;
+ /*!
+ * \brief Get the performance information for a given output stripe config.
+ * \param output_stripe_config The output stripe config to compute the performance for.
+ * \param is_rolling Whether the output config should be computed as a rolling buffer.
+ * \return The performance information containing the compute cycles and read/write bytes.
+ */
+ virtual const PerformanceInfo GetPerformanceInfo(const StripeConfig& output_stripe_config,
+ bool is_rolling) = 0;
+
+ static constexpr const char* _type_key = "contrib.ethosu.cascader.Part";
+ TVM_DECLARE_BASE_OBJECT_INFO(PartNode, Object);
+
+ protected:
+ friend class Part;
+
+ /*! \brief The Tensor Expression subgraph represented by the Part */
+ TESubgraph subgraph_;
+ /*! \brief The output->input propagators */
+ std::vector<Propagator> propagators_;
+ /*! \brief Whether the Part is computed in-line */
+ bool in_line_;
+ /*! \brief The input tensors */
+ std::vector<Tensor> input_tensors_;
+ /*! \brief The output tensor */
+ Tensor output_tensor_;
+};
+
+/*!
+ * \brief A class to describe a Part in a Cascader graph.
+ * \note Cascader graphs consist of two object types: Tensors and Parts. This class
+ * defines the Parts which represent the operations which produce and consume Tensors.
+ *
+ * A Part can represent one or more Tensor Expression compute operations but the subgraph
+ * it represents must have only a single output. Multiple TE compute operations should be
+ * represented under a single Part if the intermediate tensors between them won't be
+ * realized. This is a common pattern in Ethos-U where a sequence of TE compute operations
+ * are used to represent a single hardware primitive operation.
+ *
+ * Parts contain a Propagator per input which describes how a given output stripe config
+ * should be transformed into an input stripe config for each input. This is essential
+ * to analyse both the performance of Parts (determining the data that will be read) and
+ * in cascading Parts together (determining compatible stripe config choices).
+ *
+ * A Part can be marked as 'in_line', in which case it is assumed that it doesn't need to
+ * allocate space for its output tensor.
+ *
+ * This is only a base class and concrete Parts must be derived from it, implementing a
+ * function to model the performance of the Part as well as to determine its compute
+ * quantum.
+ */
+class Part : public ObjectRef {
+ public:
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Part, ObjectRef, PartNode);
+};
+
+/*! \brief Node to represent a CascaderGraph */
+class CascaderGraphNode : public Object {
+ public:
+ CascaderGraphNode() {}
+ CascaderGraphNode(std::vector<Tensor> input_tensors, std::vector<Tensor> output_tensors);
+
+ void VisitAttrs(AttrVisitor* v);
+
+ /*! \return The input Tensors of the CascaderGraph */
+ std::vector<Tensor> GetInputTensors() const { return input_tensors_; }
+ /*! \return The output Tensors of the CascaderGraph */
+ std::vector<Tensor> GetOutputTensors() const { return output_tensors_; }
+ /*! \return The order of the Parts in the CascaderGraph */
+ std::vector<Part> GetPartOrder() const { return part_order_; }
+ /*!
+ * \brief Get the ID of a Part in the CascaderGraph.
+ * \param part The Part to get the ID of.
+ * \return The ID of the Part in the CascaderGraph.
+ * \note Each Part is given a unique ID within the CascaderGraph.
+ */
+ int GetPartID(const Part& part) const;
+ /*!
+ * \brief Get the ID of a Tensor in the CascaderGraph.
+ * \param tensor The Tensor to get the ID of.
+ * \return The ID of the Tensor in the CascaderGraph.
+ * \note Each Tensor is given a unique ID within the CascaderGraph.
+ */
+ int GetTensorID(const Tensor& tensor) const;
+
+ static constexpr const char* _type_key = "contrib.ethosu.cascader.CascaderGraph";
+ TVM_DECLARE_FINAL_OBJECT_INFO(CascaderGraphNode, Object);
+
+ protected:
+ /*!
+ * \brief Initialize the CascaderGraph by defining a topological ordering.
+ * \note This will traverse the Parts and Tensors using a depth-first
+ * visiting pattern and use the traversal order to initialize both the
+ * 'order' vectors and the ID maps. The order vectors define the ordering
+ * that the cascader expects the CascaderGraph to be executed in, but reversed.
+ * The ID maps assign a unique integer ID to each Part and Tensor corresponding
+ * to their position in their respective order vector.
+ */
+ void Init_();
+
+ /*! \brief The input Tensors of the CascaderGraph */
+ std::vector<Tensor> input_tensors_;
+ /*! \brief The output Tensors of the CascaderGraph */
+ std::vector<Tensor> output_tensors_;
+ /*! \brief The order of the Tensors in the CascaderGraph */
+ std::vector<Tensor> tensor_order_;
+ /*! \brief The order of the Parts in the CascaderGraph */
+ std::vector<Part> part_order_;
+ /*! \brief A map between Parts in the CascaderGraph and their IDs */
+ std::unordered_map<Part, int, ObjectPtrHash, ObjectPtrEqual> part_id_map_;
+ /*! \brief A map between Tensors in the CascaderGraph and their IDs */
+ std::unordered_map<Tensor, int, ObjectPtrHash, ObjectPtrEqual> tensor_id_map_;
+};
+
+/*!
+ * \brief A class to describe a graph of Parts and Tensors used by the cascader.
+ * \note This class describes a graph consisting of two object types: Tensors and Parts.
+ * It defines a topological ordering on the graph such that each Part and Tensor has a
+ * position in the ordering. This ordering is used by the Plan and Proposal generation
+ * algorithms. It is also the ordering the Parts are expected to be executed in.
+ *
+ * In addition to defining an ordering, the Parts and Tensors are also all given unique
+ * IDs which they can be referred to by.
+ */
+class CascaderGraph : public ObjectRef {
+ public:
+ CascaderGraph(std::vector<Tensor> input_tensors, std::vector<Tensor> output_tensors);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(CascaderGraph, ObjectRef, CascaderGraphNode);
+};
+
+} // namespace cascader
+} // namespace ethosu
+} // namespace contrib
+} // namespace tvm
+
+#endif // TVM_CONTRIB_ETHOSU_CASCADER_GRAPH_H_
diff --git a/src/contrib/ethosu/cascader/parts/inline.cc b/src/contrib/ethosu/cascader/parts/inline.cc
new file mode 100644
index 0000000..ff5e055
--- /dev/null
+++ b/src/contrib/ethosu/cascader/parts/inline.cc
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "inline.h"
+
+#include <tvm/runtime/registry.h>
+
+#include <utility>
+#include <vector>
+
+#include "../common.h"
+
+namespace tvm {
+namespace contrib {
+namespace ethosu {
+namespace cascader {
+
+/*
+ * Inline Parts are modelled as free: the result always reports zero compute cycles,
+ * zero written bytes and a zero read count for every input tensor. Neither
+ * output_stripe_config nor is_rolling influences the returned value.
+ */
+const PerformanceInfo InlinePartNode::GetPerformanceInfo(const StripeConfig& output_stripe_config,
+                                                         bool is_rolling) {
+  // One zero-valued entry per input tensor of this Part.
+  std::vector<size_t> zero_read_bytes(input_tensors_.size(), 0);
+  return PerformanceInfo(0, zero_read_bytes, 0);
+}
+
+/*
+ * Build an InlinePartNode holding the given subgraph and Propagators.
+ *
+ * The node is flagged as in-line and receives one (initially undefined) input tensor
+ * slot per Propagator. At least one Propagator is required, otherwise the ICHECK fails.
+ */
+InlinePart::InlinePart(const TESubgraph& subgraph, const std::vector<Propagator> propagators) {
+  ICHECK_GT(propagators.size(), 0) << "The Part must include at least one Propagator.";
+  auto n = make_object<InlinePartNode>();
+  n->subgraph_ = subgraph;
+  // `propagators` is const, so std::move() on it would silently degrade to a copy;
+  // copy explicitly instead of suggesting a move that never happens.
+  n->propagators_ = propagators;
+  n->in_line_ = true;
+  // Size the input slots from the stored member rather than the parameter, so this
+  // stays correct even if the parameter is ever genuinely moved from.
+  n->input_tensors_.resize(n->propagators_.size());
+  data_ = std::move(n);
+}
+
+// FFI entry point: builds an InlinePart from TVM Arrays by converting them into the
+// std::vector based containers used by the C++ constructor.
+TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.InlinePart")
+    .set_body_typed([](Array<te::Tensor> subgraph_inputs, te::Tensor subgraph_output,
+                       Array<Propagator> propagators) {
+      TESubgraph subgraph;
+      subgraph.input_tensors =
+          std::vector<te::Tensor>(subgraph_inputs.begin(), subgraph_inputs.end());
+      subgraph.output_tensor = subgraph_output;
+      std::vector<Propagator> part_propagators(propagators.begin(), propagators.end());
+      return InlinePart(subgraph, part_propagators);
+    });
+
+TVM_REGISTER_NODE_TYPE(InlinePartNode);
+
+} // namespace cascader
+} // namespace ethosu
+} // namespace contrib
+} // namespace tvm
diff --git a/src/contrib/ethosu/cascader/parts/inline.h b/src/contrib/ethosu/cascader/parts/inline.h
new file mode 100644
index 0000000..44f2762
--- /dev/null
+++ b/src/contrib/ethosu/cascader/parts/inline.h
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/contrib/ethosu/cascader/parts/inline.h
+ * \brief Inline Part object
+ */
+#ifndef TVM_CONTRIB_ETHOSU_CASCADER_PARTS_INLINE_H_
+#define TVM_CONTRIB_ETHOSU_CASCADER_PARTS_INLINE_H_
+
+#include <tvm/runtime/object.h>
+
+#include <vector>
+
+#include "../graph.h"
+
+namespace tvm {
+namespace contrib {
+namespace ethosu {
+namespace cascader {
+
+/*! \brief Node to represent an inlined Part */
+class InlinePartNode : public PartNode {
+ public:
+ /*!
+ * \brief Get the performance information for a given output stripe config.
+ * \param output_stripe_config The output stripe config to compute the performance for.
+ * \param is_rolling Whether the output config should be computed as a rolling buffer.
+ * \return The performance information containing the compute cycles and read/write bytes.
+ * \note The implementation in inline.cc ignores both arguments and reports zero compute
+ * cycles, zero write bytes and a zero read-bytes entry for each input tensor.
+ */
+ const PerformanceInfo GetPerformanceInfo(const StripeConfig& output_stripe_config,
+ bool is_rolling) final;
+
+ static constexpr const char* _type_key = "contrib.ethosu.cascader.InlinePart";
+ TVM_DECLARE_FINAL_OBJECT_INFO(InlinePartNode, PartNode);
+
+ protected:
+ // The InlinePart reference class fills in this node's members in its constructor.
+ friend class InlinePart;
+};
+
+/*!
+ * \brief A class to describe an inlined Part in a Cascader graph.
+ * \note Inlined Parts have a few special properties. First by IsInline being true,
+ * the Cascader will not allocate any space for the outputs of the Part. This is because
+ * they will be directly consumed as they are produced by the following Part. Second, they
+ * are assumed to be 'free' and require no cycles to execute. Lastly, as they are 'free'
+ * the compute quantum is arbitrary, but by convention it is a single tensor element.
+ *
+ * Examples of inline Parts include strided_slice, reshape and concatenate - all of which
+ * get absorbed into the DMA functionality of Ethos-U compute primitives.
+ */
+class InlinePart : public Part {
+ public:
+ /*!
+ * \brief Construct an InlinePart.
+ * \param subgraph The TE subgraph stored on the Part.
+ * \param propagators The Propagators of the Part. At least one is required; the Part is
+ * given one input tensor slot per Propagator.
+ */
+ InlinePart(const TESubgraph& subgraph, const std::vector<Propagator> propagators);
+
+ // Reference-class boilerplate binding this (mutable) ObjectRef to InlinePartNode.
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(InlinePart, Part, InlinePartNode);
+};
+
+} // namespace cascader
+} // namespace ethosu
+} // namespace contrib
+} // namespace tvm
+
+#endif // TVM_CONTRIB_ETHOSU_CASCADER_PARTS_INLINE_H_
diff --git a/tests/python/contrib/test_ethosu/cascader/__init__.py b/tests/python/contrib/test_ethosu/cascader/__init__.py
index 7a08f7d..5d43783 100644
--- a/tests/python/contrib/test_ethosu/cascader/__init__.py
+++ b/tests/python/contrib/test_ethosu/cascader/__init__.py
@@ -14,4 +14,4 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Test infrastructure for the NPU planner"""
+"""Test infrastructure for the NPU cascader"""
diff --git a/tests/python/contrib/test_ethosu/cascader/test_graph.py b/tests/python/contrib/test_ethosu/cascader/test_graph.py
new file mode 100644
index 0000000..f00eb96
--- /dev/null
+++ b/tests/python/contrib/test_ethosu/cascader/test_graph.py
@@ -0,0 +1,134 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+from tvm.contrib.ethosu.cascader import (
+ StripeConfig,
+ Propagator,
+ Tensor,
+ InlinePart,
+ TESubgraph,
+ CascaderGraph,
+)
+
+
+def test_tensor():
+ shape = [1, 2, 3]
+ dtype = "uint8"
+ is_constant = True
+ compression_ratio = 0.5
+ size = 6
+ tensor = Tensor(shape, dtype, is_constant, compression_ratio)
+ assert tensor.shape == shape
+ assert tensor.dtype == dtype
+ assert tensor.is_constant == is_constant
+ assert tensor.compression_ratio == compression_ratio
+ assert tensor.size == size
+
+
+def test_inline_part():
+ subgraph = TESubgraph([], None)
+ part = InlinePart(
+ subgraph,
+ [
+ Propagator(
+ [[0, 1, 0], [1, 0, 0], [0, 0, 1]],
+ [0, 0],
+ ),
+ ],
+ )
+ output_stripe_config = StripeConfig([2, 4], [8, 8], [2, 4], [1, 2], [4, 2], [0, 0])
+ input_stripe_config = StripeConfig([4, 2], [8, 8], [4, 2], [2, 1], [2, 4], [0, 0])
+
+ assert part.input_tensors == [None]
+ assert part.output_tensor == None
+ assert len(part.propagators) == 1
+ assert part.in_line == True
+ assert part.get_stripe_align_hint() == [1, 1]
+ performance_info = part.get_performance_info(output_stripe_config, is_rolling=False)
+ assert performance_info.compute_cycles == 0
+ assert performance_info.read_bytes == [0]
+ assert performance_info.write_bytes == 0
+ input_stripe_configs = part.calculate_input_stripe_configs(output_stripe_config)
+ assert len(input_stripe_configs) == 1
+ assert input_stripe_configs[0] == input_stripe_config
+
+
+def test_small_graph():
+ subgraph = TESubgraph([], None)
+ part_a = InlinePart(
+ subgraph,
+ [
+ Propagator(
+ [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
+ [0, 0],
+ ),
+ Propagator(
+ [[0, 1, 0], [1, 0, 0], [0, 0, 1]],
+ [-1, -1],
+ ),
+ ],
+ )
+ part_b = InlinePart(
+ subgraph,
+ [
+ Propagator(
+ [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
+ [0, 0],
+ ),
+ ],
+ )
+ tensor_1 = Tensor([10, 10], "uint8")
+ tensor_2 = Tensor([9, 9], "uint8")
+ tensor_3 = Tensor([10, 10], "uint8")
+ tensor_4 = Tensor([10, 10], "uint8")
+
+ part_a.set_input(0, tensor_1)
+ part_a.set_input(1, tensor_2)
+ part_a.set_output(tensor_3)
+ tensor_1.add_consumer(part_a)
+ tensor_2.add_consumer(part_a)
+ tensor_3.add_producer(part_a)
+ part_b.set_input(0, tensor_3)
+ part_b.set_output(tensor_4)
+ tensor_3.add_consumer(part_b)
+ tensor_4.add_producer(part_b)
+
+ assert part_a.input_tensors == [tensor_1, tensor_2]
+ assert part_a.output_tensor == tensor_3
+ assert part_b.input_tensors == [tensor_3]
+ assert part_b.output_tensor == tensor_4
+
+ assert tensor_1.producers == []
+ assert tensor_1.consumers == [part_a]
+ assert tensor_2.producers == []
+ assert tensor_2.consumers == [part_a]
+ assert tensor_3.producers == [part_a]
+ assert tensor_3.consumers == [part_b]
+ assert tensor_4.producers == [part_b]
+ assert tensor_4.consumers == []
+
+ graph = CascaderGraph([tensor_1, tensor_2], [tensor_4])
+ assert graph.input_tensors == [tensor_1, tensor_2]
+ assert graph.output_tensors == [tensor_4]
+ assert graph.part_order == [part_b, part_a]
+ for i, part in enumerate(graph.part_order):
+ assert graph.get_part_id(part) == i
+
+
+# Allow running this test module directly as a script; delegates to pytest.
+if __name__ == "__main__":
+ pytest.main([__file__])