You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2024/02/16 14:46:28 UTC
(tvm) branch main updated: [Marvell BYOC]: Marvell AI Accelerator Integration - Phase 1 (#16570)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5645c52c6d [Marvell BYOC]: Marvell AI Accelerator Integration - Phase 1 (#16570)
5645c52c6d is described below
commit 5645c52c6d3105fb6c58cb7e1d983eff6ff26c19
Author: Krishna Bindumadhavan <31...@users.noreply.github.com>
AuthorDate: Fri Feb 16 20:16:23 2024 +0530
[Marvell BYOC]: Marvell AI Accelerator Integration - Phase 1 (#16570)
---
CMakeLists.txt | 2 +
cmake/config.cmake | 3 +
cmake/modules/LibInfo.cmake | 1 +
.../modules/contrib/Mrvl.cmake | 27 +-
.../__init__.py => docker/Dockerfile.demo_mrvl | 14 +-
docs/how_to/deploy/index.rst | 1 +
docs/how_to/deploy/mrvl.rst | 235 ++++
python/tvm/contrib/mrvl.py | 285 ++++
python/tvm/driver/tvmc/composite_target.py | 5 +
python/tvm/relay/op/contrib/__init__.py | 1 +
python/tvm/relay/op/contrib/mrvl.py | 918 +++++++++++++
python/tvm/testing/utils.py | 3 +
src/relay/backend/contrib/mrvl/codegen.cc | 1361 ++++++++++++++++++++
src/relay/backend/contrib/mrvl/compiler_attr.cc | 69 +
src/runtime/contrib/mrvl/mrvl_runtime.cc | 132 ++
src/support/libinfo.cc | 5 +
.../python/contrib/test_mrvl}/__init__.py | 14 +-
tests/python/contrib/test_mrvl/infrastructure.py | 105 ++
tests/python/contrib/test_mrvl/test_mrvl.py | 174 +++
tests/python/driver/tvmc/test_compiler.py | 18 +
tests/python/driver/tvmc/test_composite_target.py | 1 +
tests/python/driver/tvmc/test_target_options.py | 19 +
.../scripts/task_config_build_mrvl.sh | 28 +-
23 files changed, 3371 insertions(+), 50 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 058f477dbd..d10a18c4f1 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -128,6 +128,7 @@ tvm_option(USE_CLML_GRAPH_EXECUTOR "Build with CLML graph runtime" OFF)
tvm_option(USE_UMA "Build with UMA support" OFF)
tvm_option(USE_VERILATOR "Build with Verilator support" OFF)
tvm_option(USE_MSC "Enable Multi-System Compiler" OFF)
+tvm_option(USE_MRVL "Build with MRVL TVM support" OFF)
# include directories
include_directories(${CMAKE_INCLUDE_PATH})
@@ -581,6 +582,7 @@ include(cmake/modules/contrib/vllm.cmake)
include(cmake/modules/Git.cmake)
include(cmake/modules/LibInfo.cmake)
include(cmake/modules/RustExt.cmake)
+include(cmake/modules/contrib/Mrvl.cmake)
set(LIBINFO_FILE ${CMAKE_CURRENT_LIST_DIR}/src/support/libinfo.cc)
add_lib_info(${LIBINFO_FILE})
diff --git a/cmake/config.cmake b/cmake/config.cmake
index bf0a49b1aa..8caaeb7e1e 100644
--- a/cmake/config.cmake
+++ b/cmake/config.cmake
@@ -358,6 +358,9 @@ set(USE_HEXAGON_RPC OFF)
# Valid values are v65, v66, v68, v69, v73.
set(USE_HEXAGON_ARCH "v68")
+# Whether use MRVL codegen
+set(USE_MRVL OFF)
+
# Whether to use QHL library
set(USE_HEXAGON_QHL OFF)
diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake
index f6a678dca6..b971919acf 100644
--- a/cmake/modules/LibInfo.cmake
+++ b/cmake/modules/LibInfo.cmake
@@ -99,6 +99,7 @@ function(add_lib_info src_file)
TVM_INFO_USE_MICRO="${USE_MICRO}"
TVM_INFO_USE_MIOPEN="${USE_MIOPEN}"
TVM_INFO_USE_MKL="${USE_MKL}"
+ TVM_INFO_USE_MRVL="${USE_MRVL}"
TVM_INFO_USE_MSVC_MT="${USE_MSVC_MT}"
TVM_INFO_USE_NNPACK="${USE_NNPACK}"
TVM_INFO_USE_OPENCL="${USE_OPENCL}"
diff --git a/python/tvm/relay/op/contrib/__init__.py b/cmake/modules/contrib/Mrvl.cmake
similarity index 64%
copy from python/tvm/relay/op/contrib/__init__.py
copy to cmake/modules/contrib/Mrvl.cmake
index 01708e8452..0329633619 100644
--- a/python/tvm/relay/op/contrib/__init__.py
+++ b/cmake/modules/contrib/Mrvl.cmake
@@ -14,16 +14,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=wildcard-import
-"""Contrib modules."""
-from .register import get_pattern_table, register_pattern_table
-
-from .arm_compute_lib import *
-from .dnnl import *
-from .bnns import *
-from .coreml import *
-from .ethosn import *
-from .libtorch import *
-from .tensorrt import *
-from .cutlass import *
-from .clml import *
# Marvell (MRVL) BYOC support.
#
# When USE_MRVL is ON, append the Marvell runtime source to RUNTIME_SRCS and
# the Marvell Relay codegen sources to COMPILER_SRCS. The files are listed
# explicitly instead of globbed: file(GLOB) on literal names silently drops a
# missing file, whereas an explicit list makes CMake report it.
if(USE_MRVL)
  message(STATUS "Build with Mrvl support")

  # Runtime module for Marvell-partitioned subgraphs.
  set(RUNTIME_MRVL_SRCS
    ${CMAKE_CURRENT_SOURCE_DIR}/src/runtime/contrib/mrvl/mrvl_runtime.cc
  )
  list(APPEND RUNTIME_SRCS ${RUNTIME_MRVL_SRCS})

  # Relay codegen and compiler-attribute registration for the "mrvl" target.
  set(COMPILER_MRVL_SRCS
    ${CMAKE_CURRENT_SOURCE_DIR}/src/relay/backend/contrib/mrvl/codegen.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/src/relay/backend/contrib/mrvl/compiler_attr.cc
  )
  list(APPEND COMPILER_SRCS ${COMPILER_MRVL_SRCS})
endif()
diff --git a/python/tvm/relay/op/contrib/__init__.py b/docker/Dockerfile.demo_mrvl
similarity index 70%
copy from python/tvm/relay/op/contrib/__init__.py
copy to docker/Dockerfile.demo_mrvl
index 01708e8452..a99345d07f 100644
--- a/python/tvm/relay/op/contrib/__init__.py
+++ b/docker/Dockerfile.demo_mrvl
@@ -14,16 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=wildcard-import
-"""Contrib modules."""
-from .register import get_pattern_table, register_pattern_table
-from .arm_compute_lib import *
-from .dnnl import *
-from .bnns import *
-from .coreml import *
-from .ethosn import *
-from .libtorch import *
-from .tensorrt import *
-from .cutlass import *
-from .clml import *
+# prebuild ci-cpu image
+FROM tlcpack/ci-cpu:20230604-060130-0af9ff90e
diff --git a/docs/how_to/deploy/index.rst b/docs/how_to/deploy/index.rst
index ac1e2a1276..4c3f30964b 100644
--- a/docs/how_to/deploy/index.rst
+++ b/docs/how_to/deploy/index.rst
@@ -176,6 +176,7 @@ target device without relying on RPC. See the following resources on how to do s
tensorrt
vitis_ai
bnns
+ mrvl
Additional Deployment How-Tos
-----------------------------
diff --git a/docs/how_to/deploy/mrvl.rst b/docs/how_to/deploy/mrvl.rst
new file mode 100644
index 0000000000..0b0b81ed34
--- /dev/null
+++ b/docs/how_to/deploy/mrvl.rst
@@ -0,0 +1,235 @@
+.. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+.. http://www.apache.org/licenses/LICENSE-2.0
+
+.. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+
+Marvell Machine Learning Integration
+====================================
+
+1. Introduction
+---------------
+Marvell(R) supports a family of high performance Data Processing
+Units (DPUs) with integrated compute, high speed I/O and workload
+accelerators. These workload accelerators include Marvell's
+Machine Learning Inference Processor (MLIP), a highly optimized,
+integrated inference engine.
+
+TVM supports Marvell's MLIP using the "mrvl" library. It partitions the model so
+that supported operations are compiled for accelerated execution on MLIP, while
+the remaining operations are compiled with LLVM for general compute.
+
+For runtime, the library supports native execution on MLIP hardware
+as well as Marvell's ML simulator (mlModel).
+
+The library supports Marvell's Octeon family of processors with ML accelerators.
+
+This guide demonstrates building TVM with codegen and
+runtime enabled. It also provides example code to compile and run
+models using the 'mrvl' runtime.
+
+2. Building TVM with mrvl support
+---------------------------------
+
+2.1 Clone TVM repo
+-------------------
+
+Refer to the following TVM documentation for cloning TVM
+https://tvm.apache.org/docs/install/from_source.html
+
+2.2 Build and start the TVM - mrvl docker container
+----------------------------------------------------
+
+.. code:: bash
+
+ ./docker/build.sh demo_mrvl bash # Build the docker container
+ ./docker/bash.sh tvm.demo_mrvl --env PYTHONPATH=$PWD/python # Start the docker container
+
+
+3. Build TVM inside the docker container with mrvl (inside tvm directory)
+-------------------------------------------------------------------------
+
+.. code:: bash
+
+ ./tests/scripts/task_config_build_mrvl.sh build
+ cd build
+ cmake ..
+ make -j$(nproc) # nproc = 4/8/.. (Number of Parallel jobs)
+
+4. Compiling a model using TVMC command line
+--------------------------------------------
+Models can be compiled and run for mrvl target using TVMC
+which is optimized for performance.
+
+Refer to the following TVMC documentation for generic tvmc options:
+https://tvm.apache.org/docs/tutorial/tvmc_command_line_driver.html
+
+Additional mrvl-specific options may be added as attributes if
+necessary. The advanced usage is described in this document below.
+
+4.1 TVMC Compilation Flow for a model
+-------------------------------------
+
+Refer to the following TVM documentation for the compilation flow:
+https://tvm.apache.org/docs/arch/index.html#example-compilation-flow
+
+
+4.2. TVMC - Command line option(s): Syntax for mrvl target
+----------------------------------------------------------
+
+Compiling an ONNX model using the tvmc for mrvl target.
+
+**Syntax:**
+
+.. code:: bash
+
+ python3 -m tvm.driver.tvmc compile --target="mrvl, llvm"
+ --target-llvm-<options>
+ --target-mrvl-<options>
+ --<tvm-generic-options>
+ model_file.onnx
+
+The following is an example TVMC compile command for an ARMv9 core with an
+integrated cn10ka MLIP processor, using only 4 tiles in the block.
+
+**Example:**
+
+.. code:: bash
+
+ python3 -m tvm.driver.tvmc compile --target="mrvl, llvm" \
+ --target-llvm-mtriple=aarch64-linux-gnu --target-llvm-mcpu=neoverse-n2 \
+ --target-mrvl-num_tiles=4 \
+ --cross-compiler aarch64-linux-gnu-gcc \
+ --output model.tar \
+ mnist-12.onnx
+
+
+4.3. TVMC Compiler: mrvl specific Command Line Options
+------------------------------------------------------
+
+.. code:: text
+
+ --target-mrvl-mcpu
+ --target-mrvl-num_tiles
+ --target-mrvl-mattr
+
+**Description of mrvl options**
+
+* mcpu:
+ The CPU class of Marvell(R) ML Inference Processor;
+ possible values = {cn10ka, cnf10kb}; defaults to cn10ka
+
+* num_tiles:
+ Maximum number of tiles that may be used, possible values = {1,2,4,8}, defaults to 8
+
+* mattr:
+ Attributes for mrvl; possible values = {quantize, wb_pin_ocm}
+
+ mattr specifies the data type, code generation options and optimizations.
+
+ *Supported attributes:*
+
+ **1. quantize**
+
+ Specify the data type. Possible values = {fp16, int8}.
+ Default is fp16, int8 is WIP and full support will be added in a future PR.
+
+ **2. wb_pin_ocm**
+
+ Optimize runtime by preloading a model's weights and bias into
+ the on-chip memory. Possible values = {0, 1}. Default is 0 (no preload).
+
+5. Compilation - Generating model partitions
+--------------------------------------------
+
+In the TVMC mrvl flow, the model is partitioned into Marvell and LLVM regions.
+Building each partitioned Marvell subgraph generates serialized nodes.json and
+const.json. The partitioned nodes.json is a representation of the model graph that is
+suitable for the Marvell mmlc compiler, which is distributed separately via the CDK.
+
+**Model Partition**
+
+.. code:: bash
+
+ python3 -m tvm.driver.tvmc compile --target="mrvl, llvm \
+ -mtriple=aarch64-linux-gnu -mcpu=neoverse-n2" \
+ --cross-compiler aarch64-linux-gnu-gcc \
+ --target-mrvl-num_tiles=4 --output model.tar model.onnx
+
+
+6. Compiling a model using Python APIs
+--------------------------------------
+
+In addition to using TVMC, models can also be compiled and run using
+TVM Python APIs. Below is an example that compiles the MNIST model. Support
+for running the model will be added by Marvell in a future PR.
+
+**Download MNIST model from the web**
+
+.. code:: bash
+
+ cd $HOME
+ wget https://github.com/onnx/models/raw/main/validated/vision/classification/mnist/model/mnist-12.onnx
+
+**Import the TVM and other dependent modules**
+
+.. code:: python
+
+ import tvm, onnx, os
+ import numpy as np
+ import tvm.relay as relay
+ from tvm.relay.op.contrib.mrvl import partition_for_mrvl
+ from tvm.relay.build_module import build
+ from keras.datasets import mnist
+
+**Load model onnx file**
+
+.. code:: python
+
+ onnx_model = onnx.load("mnist-12.onnx")
+
+**Create a Relay graph from MNIST model**
+
+.. code:: python
+
+ shape_dict = {'Input3' : (1,1,28,28)}
+ mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
+
+**Define option dictionary and Partition the Model**
+
+Annotate and partition the graph for mrvl. All operations which are supported
+by mrvl will be marked and offloaded to the mrvl hardware accelerator. The rest of the
+operations will go through the regular LLVM compilation and code generation for ARM.
+
+.. code:: python
+
+ tvm_target = "llvm"
+
+ option_dict = {'num_tiles': 4}
+
+ mod = partition_for_mrvl(mod, params, **option_dict)
+
+**Build the Relay Graph**
+
+Build the Relay graph, using the new module returned by partition_for_mrvl.
+The target must always be an LLVM (ARM) target. ``partition_for_mrvl`` will
+pass the options from dictionary into the config parameters needed by the
+compiler backend, so there is no need to modify it - just pass it along
+to the PassContext so the values can be read during compilation.
+
+.. code:: python
+
+ with tvm.transform.PassContext(opt_level=3, config={"relay.ext.mrvl.options" : option_dict}):
+ model_lib = relay.build(mod, tvm_target, params=params)
diff --git a/python/tvm/contrib/mrvl.py b/python/tvm/contrib/mrvl.py
new file mode 100644
index 0000000000..cd0dab05ef
--- /dev/null
+++ b/python/tvm/contrib/mrvl.py
@@ -0,0 +1,285 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument, broad-except
+"""Utility to compile Marvell models"""
+
+import os
+import json
+import tvm
+import tvm._ffi
+
+
+@tvm._ffi.register_func("tvm.mrvl.GetNodesJSONString")
def get_nodes_json_string(graph_json):
    """This takes the graph_json string from MrvlJSONSerializer and adds / modifies
    the json string to a form suitable for the Marvell Backend.

    The transformation performed:

    * every ``kernel`` node becomes a ``tvm_op`` node. It gets a 1-based
      ``mrvl_nodes_idx``, the name ``tvmgen_mrvl_main_<idx-1>``, and empty
      ``kernel/bias/beta/gamma/var/mean_const`` placeholder attributes;
    * every ``input`` node gets ``layer_name``, an empty ``inputs`` list, an
      ``input_id`` (the suffix after the last ``"_i"`` in its name) and a
      ``data_layout`` (``NC`` for rank-2 shapes, ``NCHW`` otherwise).
      ``batch_size`` is read from the first dimension of the input shape;
    * each ``tvm_op``'s ``inputs`` are rewritten to reference only non-constant
      producers, using the backend numbering;
    * ``node_row_ptr`` and constant nodes are removed, and constant entries
      are dropped from ``arg_nodes`` / ``heads``;
    * ``num_subgraph_inputs`` / ``num_subgraph_outputs`` are added;
    * one level of list nesting is removed from operator attributes, and a
      top-level ``attrs`` with ``shape`` / ``dltype`` lists is built.

    Parameters
    ----------
    graph_json: String
        This is the graph_json string from the MrvlJSONSerializer

    Returns
    -------
    nodes_json_string: string
        This returns the nodes_json string which can be accepted by the Marvell backend.
    """
    dictionary = json.loads(graph_json)
    nodes = dictionary["nodes"]

    # 1. Number the kernel nodes and annotate the input nodes.
    mrvl_idx = 1
    num_in = 0
    for node in nodes:
        if node["op"] == "kernel":
            node["op"] = "tvm_op"
            attrs = node["attrs"]
            attrs["mrvl_nodes_idx"] = [mrvl_idx]
            # Empty placeholders; modify_const_names (tvm.mrvl.ModifyConstNames)
            # can later replace them with per-constant metadata.
            for const_key in (
                "kernel_const",
                "bias_const",
                "beta_const",
                "gamma_const",
                "var_const",
                "mean_const",
            ):
                attrs[const_key] = {}
            node["name"] = f"tvmgen_mrvl_main_{mrvl_idx - 1}"
            mrvl_idx += 1
        elif node["op"] == "input":
            attrs = node["attrs"]
            attrs["layer_name"] = ["input"]
            node["inputs"] = []
            node["input_id"] = [node["name"].split("_i")[-1]]
            attrs["dtype"] = attrs["dtype"][0]
            attrs["shape"] = attrs["shape"][0]
            attrs["data_layout"] = ["NC"] if len(attrs["shape"][0]) == 2 else ["NCHW"]
            # Infer Batch Size from the input shape
            dictionary["batch_size"] = f"{attrs['shape'][0][0]}"
            num_in += 1

    def backend_index(op_node):
        # Index of an operator in the backend numbering: its 1-based
        # mrvl_nodes_idx shifted by the number of subgraph inputs.
        return op_node["attrs"]["mrvl_nodes_idx"][0] + num_in - 1

    # 2. Keep only non-constant producers in each operator's inputs. Only the
    # producers' "op", "attrs" and "input_id" are read here, so rewriting
    # "inputs" in place while iterating is safe.
    for node in nodes:
        if node["op"] != "tvm_op":
            continue
        producers = []
        for entry in node["inputs"]:
            producer = nodes[entry[0]]
            if producer["op"] == "tvm_op":
                producers.append([backend_index(producer), 0, 0])
            elif producer["op"] == "input":
                producers.append([int(producer["input_id"][0]), 0, 0])
        # Delete then re-add so "inputs" ends up as the node's last key.
        del node["inputs"]
        node["inputs"] = producers

    # Remove unneeded fields
    dictionary.pop("node_row_ptr", None)

    # 3. Patch up arg_nodes to remove references to constant inputs. Input
    # nodes reuse the input_id computed above so both places agree.
    arg_nodes = []
    for idx in dictionary["arg_nodes"]:
        arg = nodes[idx]
        if arg["op"] == "const":
            continue
        if arg["op"] == "input":
            arg_nodes.append(int(arg["input_id"][0]))
        else:
            arg_nodes.append(arg["attrs"]["mrvl_nodes_idx"][0])
    dictionary["arg_nodes"] = arg_nodes
    dictionary["num_subgraph_inputs"] = str(len(arg_nodes))

    # 4. Keep only operator outputs in heads (inputs/consts have no backend index).
    heads = [
        [backend_index(nodes[entry[0]]), 0, 0]
        for entry in dictionary["heads"]
        if nodes[entry[0]]["op"] == "tvm_op"
    ]
    dictionary["heads"] = heads
    dictionary["num_subgraph_outputs"] = str(len(heads))

    # Delete the constant nodes, these are not required for the constants file
    nodes = [node for node in nodes if node["op"] != "const"]
    dictionary["nodes"] = nodes

    # 5. Remove un-needed array nesting from operator attributes.
    keep_nested = {
        "num_inputs",
        "num_outputs",
        "mrvl_nodes_idx",
        "mean_const",
        "var_const",
        "beta_const",
        "kernel_const",
        "bias_const",
        "gamma_const",
    }
    for node in nodes:
        if node["op"] == "input":
            continue
        attrs = node["attrs"]
        for key in attrs:
            if key not in keep_nested:
                attrs[key] = attrs[key][0]

    # 6. Now create the dltype and dlshape attributes
    dictionary["attrs"] = {
        "shape": ["list_shape", [node["attrs"]["shape"][0] for node in nodes]],
        "dltype": ["list_str", [node["attrs"]["dtype"][0] for node in nodes]],
    }

    return json.dumps(dictionary)
+
+
+@tvm._ffi.register_func("tvm.mrvl.ModifyConstNames")
def modify_const_names(nodes_json_str, consts_json_str):
    """This takes the graph module returned by relay.build and generates nodes and constant
    meta data suitable for compilation by the back end.

    Every ``<kind>_const_name`` attribute of a node is renamed to
    ``<node name>_const_<k>``, and the matching entry in the constants dict is
    re-keyed to that new name. A ``<kind>_const`` summary dict (shape, dtype and,
    for kernel/bias, min/max, plus the new name) is stored on the node. The
    suffix ``k`` is 0 for kernel and 1 for bias. For gamma, beta, mean and var
    it is 1, 2, 3 and 4, each shifted by one when the node also has a bias.

    Parameters
    ----------
    nodes_json_str: string
        The nodes json string suitable for the Marvell backend.

    consts_json_str: string
        The consts_json_string generated by the backend compiler.

    Returns
    -------
    modified_nodes_consts: string
        This returns a concatenated string of the nodes_json and modified
        consts json file, separated by a delimiter |. The modification to the
        consts file is necessary since we have added the Merge Compiler Pass
        which names the constants in a form unsuitable for the backend.
    """
    # name attribute -> (summary attribute, base suffix, shifted by bias, has min/max)
    const_attr_table = {
        "kernel_const_name": ("kernel_const", 0, False, True),
        "bias_const_name": ("bias_const", 1, False, True),
        "gamma_const_name": ("gamma_const", 1, True, False),
        "beta_const_name": ("beta_const", 2, True, False),
        "mean_const_name": ("mean_const", 3, True, False),
        "var_const_name": ("var_const", 4, True, False),
    }

    nodes = json.loads(nodes_json_str)
    const = json.loads(consts_json_str)
    for node in nodes["nodes"]:
        node_attrs = node["attrs"]
        has_bias = "bias_const_name" in node_attrs
        # Iterate over a snapshot: summary keys may be added to node_attrs below.
        for attr_name in list(node_attrs):
            spec = const_attr_table.get(attr_name)
            if spec is None:
                continue
            summary_key, suffix, shifted_by_bias, has_range = spec
            if shifted_by_bias and has_bias:
                suffix += 1
            new_name = f"{node['name']}_const_{suffix}"
            const[new_name] = const.pop(node_attrs[attr_name][0])
            node_attrs[attr_name][0] = new_name

            entry = const[new_name]
            summary = {"shape": entry["shape"], "dtype": entry["dtype"]}
            if has_range:
                summary["min"] = entry["min"]
                summary["max"] = entry["max"]
            summary["name"] = new_name
            node_attrs[summary_key] = summary

    nodes_mod_str = json.dumps(nodes, indent=2)
    const_mod_str = json.dumps(const, indent=2)
    return nodes_mod_str + "|" + const_mod_str
+
+
def get_working_dir():
    """Return the process's current working directory (the directory tvm was invoked from)."""
    cwd = os.getcwd()
    return cwd
diff --git a/python/tvm/driver/tvmc/composite_target.py b/python/tvm/driver/tvmc/composite_target.py
index b5d04cdba7..cfcf5a14c1 100644
--- a/python/tvm/driver/tvmc/composite_target.py
+++ b/python/tvm/driver/tvmc/composite_target.py
@@ -29,6 +29,7 @@ from tvm.relay.op.contrib.ethosu import partition_for_ethosu
from tvm.relay.op.contrib.bnns import partition_for_bnns
from tvm.relay.op.contrib.vitis_ai import partition_for_vitis_ai
from tvm.relay.op.contrib.clml import partition_for_clml
+from tvm.relay.op.contrib.mrvl import partition_for_mrvl
from tvm.driver.tvmc import TVMCException
@@ -76,6 +77,10 @@ REGISTERED_CODEGEN = {
"config_key": None,
"pass_pipeline": partition_for_clml,
},
+ "mrvl": {
+ "config_key": "relay.ext.mrvl.options",
+ "pass_pipeline": partition_for_mrvl,
+ },
}
diff --git a/python/tvm/relay/op/contrib/__init__.py b/python/tvm/relay/op/contrib/__init__.py
index 01708e8452..3a7b8db55f 100644
--- a/python/tvm/relay/op/contrib/__init__.py
+++ b/python/tvm/relay/op/contrib/__init__.py
@@ -27,3 +27,4 @@ from .libtorch import *
from .tensorrt import *
from .cutlass import *
from .clml import *
+from .mrvl import *
diff --git a/python/tvm/relay/op/contrib/mrvl.py b/python/tvm/relay/op/contrib/mrvl.py
new file mode 100644
index 0000000000..016e7ea7f6
--- /dev/null
+++ b/python/tvm/relay/op/contrib/mrvl.py
@@ -0,0 +1,918 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument, broad-except
+"""Marvell Library supported operators."""
+
+import tvm
+from tvm import relay
+from tvm.relay.build_module import bind_params_by_name
+from tvm.relay.expr_functor import ExprMutator, ExprVisitor
+from tvm.relay.expr import Call, TupleGetItem
+from tvm.contrib import mrvl as mrvl_contrib
+
+from ...dataflow_pattern import (
+ wildcard,
+ is_op,
+ is_constant,
+ is_tuple,
+ is_tuple_get_item,
+ is_var,
+)
+from .register import register_pattern_table
+from ..strategy.generic import is_depthwise_conv2d
+
+
def partition_for_mrvl(
    mod,
    params=None,
    **kwargs,
):
    """Partition the graph greedily into Marvell graph region(s) and LLVM region(s).

    Ops not supported by the Marvell backend are left in the LLVM region(s).

    Parameters
    ----------
    mod : Module
        The module to run passes on.
    params : Optional[Dict[str, NDArray]]
        Constant input parameters; when non-empty they are bound into ``main``.
    kwargs : Dict[str, str]
        Marvell options (for example ``mattr`` or ``num_tiles``), forwarded
        unchanged to ``add_attributes``.

    Returns
    -------
    mod_mrvl_llvm_regions : annotated & partitioned module (of Mrvl region(s) & LLVM region(s))
    """
    compiler_name = "mrvl"

    # Register the Marvell attribute hooks that ConvertLayout uses for these ops.
    mrvl_register_conv2d_attr_funcs_for_convert_layout()
    mrvl_register_max_pool2d_attr_funcs_for_convert_layout()
    mrvl_register_avg_pool2d_attr_funcs_for_convert_layout()
    mrvl_register_global_avg_pool2d_attr_funcs_for_convert_layout()

    if params:
        mod["main"] = bind_params_by_name(mod["main"], params)

    # Layouts applied before annotating/partitioning for the Marvell target.
    nhwc_layouts = {
        "nn.conv2d": ["NHWC", "OHWI"],
        "nn.max_pool2d": ["NHWC"],
        "nn.avg_pool2d": ["NHWC"],
        "nn.global_avg_pool2d": ["NHWC"],
    }
    pre_repartition = tvm.transform.Sequential(
        passes=[
            relay.transform.InferType(),
            MrvlRemoveDropoutPass(),
            MrvlRemoveCopyPass(),
            relay.transform.RemoveUnusedFunctions(),
            relay.transform.FoldConstant(),
            relay.transform.SimplifyExpr(),
            relay.transform.InferType(),
            relay.transform.ConvertLayout(nhwc_layouts),
            relay.transform.FoldConstant(),
            relay.transform.SimplifyExpr(),
            relay.transform.InferType(),
            relay.transform.MergeComposite(mrvl_pattern_table()),
            # Second argument: also annotate non-call ops.
            relay.transform.AnnotateTarget(compiler_name, True),
            relay.transform.MergeCompilerRegions(),
            relay.transform.PartitionGraph(""),
            relay.transform.InferType(),
        ]
    )

    # convert layout back to NCHW for ops in main
    nchw_layouts = {
        "nn.conv2d": ["NCHW", "OIHW"],
        "nn.max_pool2d": ["NCHW"],
        "nn.avg_pool2d": ["NCHW"],
        "nn.global_avg_pool2d": ["NCHW"],
    }
    post_repartition = tvm.transform.Sequential(
        passes=[
            # Convert Layout of conv ops in main to NCHW (as expected by LLVM).
            # This pass does not change layout of ops already partitioned into
            # Marvell regions.
            relay.transform.ConvertLayout(nchw_layouts),
            relay.transform.FoldConstant(),
            relay.transform.SimplifyExpr(),
            relay.transform.InferType(),
        ]
    )

    with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
        partitioned = pre_repartition(mod)
        # Inline back Marvell partitions that are not worth offloading or unsupported.
        partitioned = repartition_mrvl_subgraphs(partitioned)
        partitioned = post_repartition(partitioned)
        result = add_attributes(partitioned, compiler_name, **kwargs)

    return result
+
+
def is_activation(pattern):
    """
    Extend ``pattern`` with an optional Marvell-supported activation.

    Builds an alternation of the supported activation ops (currently only
    ``nn.relu``) and returns ``pattern.optional(...)`` of it. The result matches
    ``pattern`` either alone or with one of those activations applied on top.
    """
    supported_activations = ["nn.relu"]
    combined = None
    for op_name in supported_activations:
        op_pattern = is_op(op_name)
        combined = op_pattern if combined is None else combined | op_pattern
    return pattern.optional(combined)
+
+
class IsComputeIntensiveGraph(ExprVisitor):
    """
    Expression visitor that records whether a Relay expression contains a
    compute-heavy operator call (``nn.conv2d`` or ``nn.dense``).
    """

    def __init__(self):
        super().__init__()
        self.is_compute_intensive = False

    def visit_call(self, call):
        heavy_op_names = ("nn.conv2d", "nn.dense")
        op = call.op
        # Only primitive operator calls have a name to compare against.
        if isinstance(op, tvm.tir.op.Op) and str(op.name) in heavy_op_names:
            self.is_compute_intensive = True
        return super().visit_call(call)

    def is_graph_compute_intensive(self, subgraph):
        """
        Visit ``subgraph`` recursively and return True if any compute-heavy call was seen.
        """
        self.visit(subgraph)
        return self.is_compute_intensive
+
+
class IsSupportedGraph(ExprVisitor):
    """
    Visits the graph recursively and checks if function inputs feed into
    any unsupported ops.

    A call is inspected when one of its arguments is a parameter of ``function``
    or a call already recorded in ``input_op_list``. For each composite function
    reachable from such a call:

    * ``mrvl.sum2d`` or ``mrvl.concat`` marks the graph unsupported;
    * ``mrvl.reshape`` records the call, so calls consuming it are inspected too.
    """

    def __init__(self, function):
        ExprVisitor.__init__(self)
        self.is_supported = True
        self.function = function
        self.input_op_list = []

    def _check_legal(self, node, parent_call):
        unsupported_ops = {
            "mrvl.sum2d",
            "mrvl.concat",
        }

        input_ops = {
            "mrvl.reshape",
        }

        if not isinstance(node, relay.Function):
            return
        # Only composite functions carry a "Composite" attribute. Skip any other
        # function (attrs may be None or lack the key) instead of raising.
        attrs = node.attrs
        if not attrs or "Composite" not in attrs:
            return
        composite = attrs["Composite"]
        if composite in unsupported_ops:
            self.is_supported = False
        if composite in input_ops:
            self.input_op_list.append(parent_call)

    def visit_call(self, call):
        for arg in call.args:
            if arg in self.function.params or arg in self.input_op_list:
                relay.analysis.post_order_visit(
                    call, lambda expr, parent_call=call: self._check_legal(expr, parent_call)
                )

        return super().visit_call(call)

    def is_supported_subgraph(self):
        """
        Visit the function body recursively and return True if the graph is legal.
        """
        self.visit(self.function.body)
        return self.is_supported
+
+
def first_op_unsupported(function):
    """Return True when ``function``'s inputs feed an op the Marvell backend cannot accept.

    Thin wrapper that negates ``IsSupportedGraph(function).is_supported_subgraph()``.
    """
    checker = IsSupportedGraph(function)
    supported = checker.is_supported_subgraph()
    return not supported
+
+
def repartition_subgraph(function):
    """
    Decide whether a Marvell partition should be reverted to LLVM.

    Returns True when the function body has no compute-intensive op (per
    IsComputeIntensiveGraph), or when its inputs feed an op the backend cannot
    accept (per first_op_unsupported). Returns False otherwise. The second check
    is skipped when the first already decides.
    """
    compute_heavy = IsComputeIntensiveGraph().is_graph_compute_intensive(function.body)
    return (not compute_heavy) or first_op_unsupported(function)
+
+
+def repartition_mrvl_subgraphs(mod):
+ """
+ Un-partition those partitions which:
+ - are not computationally intensive subgraph
+ - cannot be supported by the backend currently
+ """
+ global_vars_to_inline = [
+ gv
+ for gv in mod.get_global_vars()
+ if mod[gv].attrs and mod[gv].attrs["Compiler"] == "mrvl" and repartition_subgraph(mod[gv])
+ ]
+ return relay.transform.InlineCompilerFunctionsBoundTo(global_vars_to_inline)(mod)
+
+
+def add_attributes(mod, annotate_target_str, **kwargs):
+ """This method iterates across all Marvell partitioned functions in the
+ module and attaches attributes which are supplied by the user from the CLI.
+ Use good defaults in case a particular option is not specified. These options
+ are later accessed by codegen and are embedded into the runtime.
+
+ Parameters
+ ----------
+ mod : Module
+ The module to attach attributes to
+ kwargs : Dict[str, str]
+ Dictionary with command line options
+
+ Returns
+ -------
+ mod : module with attributes
+ """
+ working_dir = mrvl_contrib.get_working_dir()
+
+ if "mattr" in kwargs:
+ base_opts_str = kwargs.get("mattr")
+
+ # Set defaults to options if explicit command line option is not given
+ if "arch" not in base_opts_str:
+ base_opts_str = f"{base_opts_str} -arch=mlip"
+
+ if "quantize" not in base_opts_str:
+ base_opts_str = f"{base_opts_str} -quantize=fp16"
+
+ if "wb_pin_ocm" not in base_opts_str:
+ base_opts_str = f"{base_opts_str} -wb_pin_ocm=0"
+
+ else:
+ base_opts_str = "-arch=mlip -quantize=fp16 -wb_pin_ocm=0"
+
+ if "num_tiles" in kwargs:
+ base_opts_str = f"{base_opts_str} -num_tiles={kwargs.get('num_tiles')}"
+ elif "num_tiles" not in base_opts_str:
+ base_opts_str = f"{base_opts_str} -num_tiles=8"
+
+ for var in mod.get_global_vars():
+ func_name = var.name_hint
+ func = mod[func_name]
+
+ if annotate_target_str in func_name:
+ func = func.with_attr("working_dir", working_dir)
+ func = func.with_attr("compiler_opts_string", base_opts_str)
+ mod.update_func(var, func)
+
+ return mod
+
+
+def is_valid_batch_size(batch_size):
+ if isinstance(batch_size, type(relay.Any())):
+ return False
+ elif batch_size > 8:
+ return False
+ else:
+ return True
+
+
+def mrvl_register_conv2d_attr_funcs_for_convert_layout():
+ """register the conv2d attr func(s) to convert op layout"""
+ # reset first in order to register & use a new nn.conv2d convert layout function
+ relay.op.get("nn.conv2d").reset_attr("FTVMConvertOpLayout")
+
+ @tvm.ir.register_op_attr("nn.conv2d", "FTVMConvertOpLayout")
+ def convert_conv2d(attrs, inputs, tinfos, desired_layouts):
+ if not is_valid_batch_size(tinfos[0].shape[0]):
+ return relay.nn.conv2d(*inputs, **attrs)
+ new_attrs = dict(attrs)
+ weight_info_const = tinfos[1]
+ new_attrs["channels"] = weight_info_const.shape[0]
+ desired_data_layout, desired_kernel_layout = map(str, desired_layouts)
+ new_attrs["data_layout"] = desired_data_layout
+ new_attrs["kernel_layout"] = desired_kernel_layout
+ new_attrs["out_layout"] = desired_data_layout
+ return relay.nn.conv2d(*inputs, **new_attrs)
+
+ return convert_conv2d
+
+
+def mrvl_register_max_pool2d_attr_funcs_for_convert_layout():
+ """register the max_pool2d attr func(s) to convert op layout"""
+ # reset first in order to register & use a new nn.max_pool2d convert layout function
+ relay.op.get("nn.max_pool2d").reset_attr("FTVMConvertOpLayout")
+
+ @tvm.ir.register_op_attr("nn.max_pool2d", "FTVMConvertOpLayout")
+ def convert_max_pool2d(attrs, inputs, tinfos, desired_layouts):
+ if not is_valid_batch_size(tinfos[0].shape[0]):
+ return relay.nn.max_pool2d(*inputs, **attrs)
+ new_attrs = dict(attrs)
+ new_attrs["layout"] = str(desired_layouts[0])
+ new_attrs["out_layout"] = str(desired_layouts[0])
+ return relay.nn.max_pool2d(*inputs, **new_attrs)
+
+ return convert_max_pool2d
+
+
+def mrvl_register_avg_pool2d_attr_funcs_for_convert_layout():
+ """register the avg_pool2d attr func(s) to convert op layout"""
+ # reset first in order to register& use a new nn.avg_pool2d convert layout function
+ relay.op.get("nn.avg_pool2d").reset_attr("FTVMConvertOpLayout")
+
+ @tvm.ir.register_op_attr("nn.avg_pool2d", "FTVMConvertOpLayout")
+ def convert_avg_pool2d(attrs, inputs, tinfos, desired_layouts):
+ if (tinfos[0].shape[0] != 1) and not isinstance(tinfos[0].shape[0], type(relay.Any())):
+ return relay.nn.avg_pool2d(*inputs, **attrs)
+ new_attrs = dict(attrs)
+ new_attrs["layout"] = str(desired_layouts[0])
+ new_attrs["out_layout"] = str(desired_layouts[0])
+ return relay.nn.avg_pool2d(*inputs, **new_attrs)
+
+ return convert_avg_pool2d
+
+
+def mrvl_register_global_avg_pool2d_attr_funcs_for_convert_layout():
+ """register the global_avg_pool2d attr func(s) to convert op layout"""
+ # reset first in order to register& use a new nn.global_avg_pool2d convert layout function
+ relay.op.get("nn.global_avg_pool2d").reset_attr("FTVMConvertOpLayout")
+
+ @tvm.ir.register_op_attr("nn.global_avg_pool2d", "FTVMConvertOpLayout")
+ def convert_global_avg_pool2d(attrs, inputs, tinfos, desired_layouts):
+ if (tinfos[0].shape[0] != 1) and not isinstance(tinfos[0].shape[0], type(relay.Any())):
+ return relay.nn.global_avg_pool2d(*inputs, **attrs)
+ new_attrs = dict(attrs)
+ new_attrs["layout"] = str(desired_layouts[0])
+ new_attrs["out_layout"] = str(desired_layouts[0])
+ return relay.nn.global_avg_pool2d(*inputs, **new_attrs)
+
+ return convert_global_avg_pool2d
+
+
+@register_pattern_table("mrvl")
+def mrvl_pattern_table():
+    """Get the Mrvl pattern table.
+
+    Returns
+    -------
+    patterns : List[Tuple[str, DFPattern, Callable]]
+        One ``(composite name, dataflow pattern, check function)`` entry per
+        Mrvl composite. Each check function walks from the matched expression
+        back to the anchor op and calls the matching ``target.mrvl`` predicate.
+        NOTE(review): the entries are presumably tried in list order, so the
+        ordering below may matter; keep it unless verified otherwise.
+    """
+
+    def conv2d_nhwc2nhwc_pattern():
+        """Create a convolution-2d pattern.
+        review tvm/tests/python/relay/test_dataflow_pattern.py for examples
+
+        Returns
+        -------
+        pattern : dataflow_pattern.AltPattern
+            Denotes the convolution-2d pattern.
+        """
+
+        def conv2d_base_pattern(pattern):
+            # conv2d(<input>, constant weight) + optional (bias_add | add) of a constant
+            pattern = is_op("nn.conv2d")(pattern, is_constant())
+            pattern = pattern.optional(
+                lambda x: (is_op("nn.bias_add")(x, is_constant()) | is_op("add")(x, is_constant()))
+            )
+
+            def conv2d_no_batchnorm(pattern):
+                # conv + [add] + [relu]
+                pattern1 = is_activation(pattern)
+                return pattern1
+
+            def conv2d_batchnorm(pattern):
+                # conv + [add] + batch_norm with constant parameters, taking
+                # tuple element 0 of its result, followed by is_activation(...)
+                pattern2 = is_op("nn.batch_norm")(
+                    pattern, is_constant(), is_constant(), is_constant(), is_constant()
+                )
+                pattern2 = is_tuple_get_item(pattern2, 0)
+                pattern2 = is_activation(pattern2)
+                return pattern2
+
+            pattern1 = conv2d_no_batchnorm(pattern)
+            pattern2 = conv2d_batchnorm(pattern)
+
+            return pattern1 | pattern2
+
+        # The convolution input may come through an explicit nn.pad or not.
+        pad = is_op("nn.pad")(wildcard(), wildcard())
+        pad = conv2d_base_pattern(pad)
+        no_pad = wildcard()
+        no_pad = conv2d_base_pattern(no_pad)
+
+        return pad | no_pad
+
+    def sum2d_pattern():
+        """Create a sum2d pattern.
+        review tvm/tests/python/relay/test_dataflow_pattern.py for examples
+
+        Returns
+        -------
+        pattern : dataflow_pattern.AltPattern
+            Denotes the sum2d pattern.
+        """
+        # add(<any>, <any>) followed by is_activation(...)
+        pattern = is_op("add")(wildcard(), wildcard())
+        pattern = is_activation(pattern)
+        return pattern
+
+    def concat_pattern():
+        """Create a concat pattern.
+        review tvm/tests/python/relay/test_dataflow_pattern.py for examples
+
+        Returns
+        -------
+        pattern : dataflow_pattern.AltPattern
+            Denotes the concat pattern.
+        """
+        # concatenate over an explicit tuple expression
+        pattern = is_op("concatenate")(is_tuple(None))
+        return pattern
+
+    def fc_pattern():
+        """Create a fc (fully-connected) pattern.
+        review tvm/tests/python/relay/test_dataflow_pattern.py for examples
+
+        Returns
+        -------
+        pattern : dataflow_pattern.AltPattern
+            Denotes the fc pattern.
+        """
+        # dense(<input>, constant weight) + optional (bias_add | add) + is_activation
+        pattern = is_op("nn.dense")(wildcard(), is_constant())
+        pattern = pattern.optional(
+            lambda x: (is_op("nn.bias_add")(x, is_constant()) | is_op("add")(x, is_constant()))
+        )
+        pattern = is_activation(pattern)
+
+        return pattern
+
+    def maxpool2d_pattern():
+        """Create a maxpool2d pattern.
+        review tvm/tests/python/relay/test_dataflow_pattern.py for examples
+
+        Returns
+        -------
+        pattern : dataflow_pattern.AltPattern
+            Denotes the maxpool2d pattern.
+        """
+
+        def maxpool2d_base_pattern(pattern):
+            pattern = is_op("nn.max_pool2d")(pattern)
+            return pattern
+
+        # The pooling input may come through an explicit nn.pad or not.
+        pad = is_op("nn.pad")(wildcard(), wildcard())
+        pad = maxpool2d_base_pattern(pad)
+
+        no_pad = wildcard()
+        no_pad = maxpool2d_base_pattern(no_pad)
+
+        return pad | no_pad
+
+    def avgpool2d_pattern():
+        """Create a avgpool2d pattern.
+        review tvm/tests/python/relay/test_dataflow_pattern.py for examples
+        Returns
+        -------
+        pattern : dataflow_pattern.AltPattern
+            Denotes the avgpool2d pattern.
+        """
+
+        def avgpool2d_base_pattern(pattern):
+            pattern = is_op("nn.avg_pool2d")(pattern)
+
+            return pattern
+
+        # The pooling input may come through an explicit nn.pad or not.
+        pad = is_op("nn.pad")(wildcard(), wildcard())
+        pad = avgpool2d_base_pattern(pad)
+
+        no_pad = wildcard()
+        no_pad = avgpool2d_base_pattern(no_pad)
+
+        return pad | no_pad
+
+    def globalavgpool2d_pattern():
+        """Create a globalavgpool2d pattern.
+        review tvm/tests/python/relay/test_dataflow_pattern.py for examples
+        Returns
+        -------
+        pattern : dataflow_pattern.AltPattern
+            Denotes the globalavgpool2d pattern.
+        """
+        pattern = is_op("nn.global_avg_pool2d")(wildcard())
+        return pattern
+
+    def reshape_pattern():
+        pattern = is_op("reshape")(wildcard())
+        return pattern
+
+    def batch_flatten_pattern():
+        pattern = is_op("nn.batch_flatten")(wildcard())
+        return pattern
+
+    def layout_transform_nchw2nhwc_pattern():
+        # NOTE(review): layout_transform is given three argument patterns here,
+        # while layout_transform_nhwc2nchw_to_2D_pattern below gives it one.
+        # Confirm whether this pattern can ever match. The JSON codegen's
+        # composite dispatch also has no "mrvl.layout_transform_nchw2nhwc"
+        # branch, so keep the two in sync before changing it.
+        pattern = is_op("layout_transform")(is_var(), wildcard(), wildcard()).has_attr(
+            {"src_layout": "NCHW", "dst_layout": "NHWC"}
+        )
+        return pattern
+
+    def layout_transform_nhwc2nchw_to_2D_pattern():
+        # Layout_Transform + Reshape/BatchFlatten
+        transform1 = is_op("layout_transform")(wildcard()).has_attr(
+            {"src_layout": "NHWC", "dst_layout": "NCHW"}
+        )
+        pattern1 = is_op("reshape")(transform1)
+        pattern2 = is_op("nn.batch_flatten")(transform1)
+
+        return pattern1 | pattern2
+
+    # The check_* helpers below receive the matched expression and follow
+    # args[0] (and TupleGetItem.tuple_value for conv2d) back to the anchor op
+    # before calling the corresponding target.mrvl predicate.
+    def check_conv2d(extract):
+        """Check conv pattern is supported by Mrvl."""
+        call = extract
+        while isinstance(call, TupleGetItem) or (call.op.name != "nn.conv2d"):
+            if isinstance(call, TupleGetItem):
+                call = call.tuple_value
+            else:
+                call = call.args[0]
+        return conv2d_nhwc2nhwc(call)
+
+    def check_fc(extract):
+        """Check fc pattern is supported by Mrvl."""
+        call = extract
+        while call.op.name != "nn.dense":
+            call = call.args[0]
+        return fc_ni2no(call)
+
+    def check_maxpool2d(extract):
+        """Check maxpool2d pattern is supported by Mrvl."""
+        call = extract
+        while call.op.name != "nn.max_pool2d":
+            call = call.args[0]
+        return maxpool2d_nhwc2nhwc(call)
+
+    def check_avgpool2d(extract):
+        """Check avgpool2d pattern is supported by Mrvl."""
+        call = extract
+        while call.op.name != "nn.avg_pool2d":
+            call = call.args[0]
+        return avgpool2d_nhwc2nhwc(call)
+
+    def check_globalavgpool2d(extract):
+        """Check globalavgpool2d pattern is supported by Mrvl."""
+        call = extract
+        while call.op.name != "nn.global_avg_pool2d":
+            call = call.args[0]
+        return globalavgpool2d_nhwc2nhwc(call)
+
+    def check_reshape(extract):
+        call = extract
+        while call.op.name != "reshape":
+            call = call.args[0]
+        return reshape_mrvl(call)
+
+    def check_batch_flatten(extract):
+        call = extract
+        while call.op.name != "nn.batch_flatten":
+            call = call.args[0]
+        return batch_flatten_mrvl(call)
+
+    def check_layout_transform_nchw2nhwc(extract):
+        call = extract
+        while call.op.name != "layout_transform":
+            call = call.args[0]
+        return layout_transform_nchw2nhwc(call)
+
+    def check_layout_transform_nhwc2nchw_2D(extract):
+        # Accept reshape/batch_flatten over an NHWC -> NCHW layout_transform.
+        call = extract
+        if call.op.name == "reshape" or call.op.name == "nn.batch_flatten":
+            call = call.args[0]
+            if call.op.name == "layout_transform":
+                if call.attrs.src_layout == "NHWC" and call.attrs.dst_layout == "NCHW":
+                    return True
+        return False
+
+    def check_sum2d(extract):
+        """Check sum2d pattern is supported by Mrvl."""
+        call = extract
+        while call.op.name != "add":
+            call = call.args[0]
+        return sum2d(call)
+
+    def check_concat(extract):
+        """Check concat pattern is supported by Mrvl."""
+        call = extract
+        while call.op.name != "concatenate":
+            call = call.args[0]
+        return concat(call)
+
+    return [
+        ("mrvl.conv2d_nhwc2nhwc", conv2d_nhwc2nhwc_pattern(), check_conv2d),
+        ("mrvl.fc_ni2no", fc_pattern(), check_fc),
+        ("mrvl.maxpool2d_nhwc2nhwc", maxpool2d_pattern(), check_maxpool2d),
+        ("mrvl.avgpool2d_nhwc2nhwc", avgpool2d_pattern(), check_avgpool2d),
+        ("mrvl.globalavgpool2d_nhwc2nhwc", globalavgpool2d_pattern(), check_globalavgpool2d),
+        ("mrvl.sum2d", sum2d_pattern(), check_sum2d),
+        ("mrvl.concat", concat_pattern(), check_concat),
+        (
+            "mrvl.layout_transform_nhwc2nchw_reshape",
+            layout_transform_nhwc2nchw_to_2D_pattern(),
+            check_layout_transform_nhwc2nchw_2D,
+        ),
+        (
+            "mrvl.layout_transform_nchw2nhwc",
+            layout_transform_nchw2nhwc_pattern(),
+            check_layout_transform_nchw2nhwc,
+        ),
+        ("mrvl.reshape", reshape_pattern(), check_reshape),
+        ("mrvl.batch_flatten", batch_flatten_pattern(), check_batch_flatten),
+    ]
+
+
+# register a helper function to indicate that the given operator can be supported by Mrvl.
+@tvm.ir.register_op_attr("nn.conv2d", "target.mrvl")
+def conv2d_nhwc2nhwc(expr):
+ """Check if the external Mrvl codegen for conv2d_nhwc2nhwc should be used."""
+ attrs, args = expr.attrs, expr.args
+ if attrs.data_layout != "NHWC":
+ return False
+ if attrs.out_dtype != "float32" and attrs.out_dtype != "":
+ return False
+ data_type = args[0].checked_type
+ if (
+ (len(data_type.shape) != 4)
+ or not is_valid_batch_size(data_type.shape[0])
+ or (data_type.dtype not in ["float32"])
+ ):
+ return False
+ kernel_typ = args[1].checked_type
+ if (len(kernel_typ.shape) != 4) or (kernel_typ.dtype not in ["float32"]):
+ return False
+
+ is_depthwise = is_depthwise_conv2d(
+ data_type.shape,
+ attrs["data_layout"],
+ kernel_typ.shape,
+ attrs["kernel_layout"],
+ attrs["groups"],
+ )
+ if is_depthwise:
+ # Mrvl support grouped conv only for groups == ch
+ return bool(attrs.groups == kernel_typ.shape[0])
+ if attrs.groups != 1 and not is_depthwise:
+ return False
+ return True
+
+
+# register a helper function to indicate that the given operator can be supported by Mrvl.
+@tvm.ir.register_op_attr("add", "target.mrvl")
+def sum2d(expr):
+ """Check if the external Mrvl codegen for sum2d should be used."""
+ arg0 = expr.args[0]
+
+ # - need to further checking if the call_func of arg0 is not nn.conv2d nor nn.dense
+ if (
+ isinstance(arg0, Call)
+ and isinstance(arg0.op, tvm.ir.Op)
+ and arg0.op.name in ["nn.conv2d", "nn.dense"]
+ ):
+ return False
+
+ # - need to further checking if dimension of input or output tensor is 4
+ data_type = arg0.checked_type
+ if (
+ (len(data_type.shape) != 4)
+ or not is_valid_batch_size(data_type.shape[0])
+ or (data_type.dtype not in ["float32"])
+ ):
+ return False
+
+ return True
+
+
+# register a helper function to indicate that the given operator can be supported by Mrvl.
+@tvm.ir.register_op_attr("concatenate", "target.mrvl")
+def concat(expr):
+ """Check if the external Mrvl codegen for concat should be used."""
+ attrs, args = expr.attrs, expr.args
+ arg0 = args[0]
+ assert not isinstance(arg0, Call)
+
+ # check data types for both inputs
+ # - only support 4-dimension input tensors in NHWC
+ # - only support batch size is 1
+ data_type_a = arg0.checked_type.fields[0]
+ data_type_b = arg0.checked_type.fields[1]
+ if (
+ (len(data_type_a.shape) != 4)
+ or (len(data_type_b.shape) != 4)
+ or (data_type_a.shape[0] != 1)
+ or (data_type_b.shape[0] != 1)
+ or (data_type_a.dtype not in ["float32"])
+ or (data_type_b.dtype not in ["float32"])
+ ):
+ return False
+
+ for data_type in arg0.checked_type.fields:
+ if (
+ (len(data_type.shape) != 4)
+ or (data_type.shape[0] != 1)
+ or (data_type.dtype not in ["float32"])
+ ):
+ return False
+
+ if attrs["axis"] != 3:
+ return False
+
+ return True
+
+
+# register a helper function to indicate that the given operator can be supported by Mrvl.
+@tvm.ir.register_op_attr("nn.dense", "target.mrvl")
+def fc_ni2no(expr):
+ """Check if the external Mrvl codegen for fc_ni2no should be used."""
+ attrs, args = expr.attrs, expr.args
+ data_type = args[0].checked_type
+ if data_type.dtype not in ["float32"]:
+ return False
+ kernel_typ = args[1].checked_type
+ if (len(kernel_typ.shape) != 2) or (kernel_typ.dtype not in ["float32"]):
+ return False
+ if attrs.out_dtype != "float32" and attrs.out_dtype != "":
+ return False
+ return True
+
+
+# register a helper function to indicate that the given operator can be supported by Mrvl.
+@tvm.ir.register_op_attr("nn.max_pool2d", "target.mrvl")
+def maxpool2d_nhwc2nhwc(expr):
+ """Check if the external Mrvl codegen for maxpool2d_nhwc2nhwc should be used."""
+ attrs, args = expr.attrs, expr.args
+ if attrs.layout != "NHWC":
+ return False
+ data_type = args[0].checked_type
+ if (
+ (len(data_type.shape) != 4)
+ or not is_valid_batch_size(data_type.shape[0])
+ or (data_type.dtype not in ["float32"])
+ ):
+ return False
+ return True
+
+
+# register a helper function to indicate that the given operator can be supported by Mrvl.
+@tvm.ir.register_op_attr("nn.avg_pool2d", "target.mrvl")
+def avgpool2d_nhwc2nhwc(expr):
+ """Check if the external Mrvl codegen for avgpool2d_nhwc2nhwc should be used."""
+ attrs, args = expr.attrs, expr.args
+ if attrs.layout != "NHWC":
+ return False
+ data_type = args[0].checked_type
+ if (
+ (len(data_type.shape) != 4)
+ or ((data_type.shape[0] != 1) and not isinstance(data_type.shape[0], type(relay.Any())))
+ or (data_type.dtype not in ["float32"])
+ ):
+ return False
+ return True
+
+
+# register a helper function to indicate that the given operator can be supported by Mrvl.
+@tvm.ir.register_op_attr("nn.global_avg_pool2d", "target.mrvl")
+def globalavgpool2d_nhwc2nhwc(expr):
+ """Check if the external Mrvl codegen for globalavgpool2d_nhwc2nhwc should be used."""
+ attrs, args = expr.attrs, expr.args
+ if attrs.layout != "NHWC":
+ return False
+ data_type = args[0].checked_type
+ if not (len(data_type.shape) == 4 or len(data_type.shape) == 2):
+ return False
+ if (
+ (len(data_type.shape) != 4)
+ or ((data_type.shape[0] != 1) and not isinstance(data_type.shape[0], type(relay.Any())))
+ or (data_type.dtype not in ["float32"])
+ ):
+ return False
+ return True
+
+
+@tvm.ir.register_op_attr("reshape", "target.mrvl")
+def reshape_mrvl(expr):
+ """Check if the external Mrvl codegen for reshape should be used."""
+ if expr.op.name != "reshape":
+ return False
+ else:
+ data_type = expr.checked_type
+ if not (len(data_type.shape) == 4 or len(data_type.shape) == 2):
+ return False
+
+ args = expr.args
+ data_type = args[0].checked_type
+ return True
+
+
+@tvm.ir.register_op_attr("nn.batch_flatten", "target.mrvl")
+def batch_flatten_mrvl(expr):
+ """Check if the external Mrvl codegen for batch_flatten should be used."""
+ if expr.op.name != "nn.batch_flatten":
+ return False
+ else:
+ data_type = expr.checked_type
+ if len(data_type.shape) != 2:
+ return False
+
+ args = expr.args
+ data_type = args[0].checked_type
+
+ if not (len(data_type.shape) == 4 or len(data_type.shape) == 2):
+ return False
+
+ return True
+
+
+# register a helper function to indicate that the given operator can be supported by Mrvl.
+@tvm.ir.register_op_attr("layout_transform", "target.mrvl")
+def layout_transform_nchw2nhwc(expr):
+ """Check if the external Mrvl codegen for Layout Transform should be used."""
+ attrs, args = expr.attrs, expr.args
+ if attrs.src_layout != "NCHW":
+ return False
+ if attrs.dst_layout != "NHWC":
+ return False
+ data_type = args[0].checked_type
+ if data_type.dtype not in ["float32"]:
+ return False
+ return True
+
+
+class RemoveDropout(ExprMutator):
+ """Removes all nn.dropout from an expr."""
+
+ def visit_tuple_getitem(self, op):
+ visit = super().visit_tuple_getitem(op)
+ if visit.index != 0:
+ return visit
+ if (
+ isinstance(visit.tuple_value, Call)
+ and visit.tuple_value.op.name == "nn.dropout"
+ and visit.index == 0
+ ):
+ # skip nn.dropout call and return arg0 instead
+ return visit.tuple_value.args[0]
+ return visit
+
+
+@relay.transform.function_pass(opt_level=0)
+class MrvlRemoveDropoutPass:
+ """Removes Dropouts."""
+
+ def transform_function(self, func, mod, _):
+ """call RemoveDropout func."""
+ return RemoveDropout().visit(func)
+
+
+class RemoveCopy(ExprMutator):
+ """
+ Delete Copy expression
+ """
+
+ def visit_call(self, call):
+ visit = super().visit_call(call)
+ if visit.op.name in ["copy"]:
+ return visit.args[0]
+ return visit
+
+
+@relay.transform.function_pass(opt_level=0)
+class MrvlRemoveCopyPass:
+ """Removes Copy."""
+
+ def transform_function(self, func, mod, _):
+ """call RemoveCopy func."""
+ return RemoveCopy().visit(func)
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index ccad989c33..d59aa964f9 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -989,6 +989,9 @@ requires_ethosu = Feature("ethosu", "Arm(R) Ethos(TM)-U", cmake_flag="USE_ETHOSU
# Mark a test as requiring libtorch to run
requires_libtorch = Feature("libtorch", "LibTorch", cmake_flag="USE_LIBTORCH")
+# Mark a test as requiring the Marvell (MRVL) BYOC library, i.e. a TVM build
+# configured with the USE_MRVL CMake flag.
+requires_mrvl = Feature("mrvl", "Marvell", cmake_flag="USE_MRVL")
+
# Mark a test as requiring Hexagon to run
requires_hexagon = Feature(
"hexagon",
diff --git a/src/relay/backend/contrib/mrvl/codegen.cc b/src/relay/backend/contrib/mrvl/codegen.cc
new file mode 100644
index 0000000000..527b53acf4
--- /dev/null
+++ b/src/relay/backend/contrib/mrvl/codegen.cc
@@ -0,0 +1,1361 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/contrib/mrvl/codegen.cc
+ * \brief Marvell MLIP specific API
+ */
+
+#include <stdio.h>
+#include <tvm/ir/module.h>
+#include <tvm/relay/type.h>
+#include <tvm/tir/analysis.h>
+
+#include <iomanip>
+#include <iostream>
+#include <limits>
+#include <memory>
+#include <regex>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../../../../runtime/contrib/json/json_node.h"
+#include "../../../../support/base64.h"
+#include "../../../qnn/utils.h"
+#include "../../utils.h"
+#include "../codegen_json/codegen_json.h"
+
+namespace tvm {
+namespace relay {
+
+namespace contrib {
+namespace mrvl {
+
+using namespace backend;
+
+/*!
+ * \brief Text descriptor of one serialized constant. Every field holds the
+ *        string form that is written into the constants JSON.
+ */
+struct const_struct {
+  // constant name
+  std::string name;
+  // bracketed, comma separated dimensions, e.g. "[ 1, 3 ]"
+  std::string shape;
+  // data type string
+  std::string dtype;
+  // smallest element value, printed as text
+  std::string min;
+  // largest element value, printed as text
+  std::string max;
+  // raw tensor bytes, base64 encoded
+  std::string data_base64;
+};
+
+// Short aliases for the JSON runtime graph node types used throughout this file.
+using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
+using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
+
+/*!
+ * \brief Generates an MrvlModule from a relay expression. This "compilation"
+ * does not require Mrvl driver since the actual conversion using Mrvl APIs is
+ * deferred until creation of the runtime. This step simply serializes the
+ * relay program into a JSON string.
+ */
+class MrvlJSONSerializer : public backend::contrib::JSONSerializer {
+ public:
+  /*!
+   * \brief Construct a serializer for one Mrvl partition.
+   *
+   * \param symbol The symbol that represents the graph being converted; it is
+   *        also stored as the layer name.
+   * \param expr The Relay expression to be converted to the JSON form.
+   */
+  MrvlJSONSerializer(const std::string& symbol, const Expr& expr)
+      : JSONSerializer(symbol, expr), layer_name_(symbol) {}
+
+  /*!
+   * \brief Return the required params.
+   *
+   * \return The base serializer's constant names followed by the names held
+   *         in batch_norm_params_.
+   */
+  Array<String> GetParams() const {
+    Array<String> mrvl_params;
+    for (const String& const_name : JSONSerializer::const_names()) {
+      mrvl_params.push_back(const_name);
+    }
+    for (const String& bn_name : batch_norm_params_) {
+      mrvl_params.push_back(bn_name);
+    }
+    return mrvl_params;
+  }
+
+  /*!
+   * \brief Serialize a floating point value (double, float) to text.
+   *
+   * \param val The value to print.
+   * \param precision Number of significant digits (default 17).
+   * \return The textual form of val.
+   */
+  template <typename T>
+  std::string FloatToString(T val, size_t precision = 17) {
+    std::ostringstream out;
+    out << std::setprecision(static_cast<int>(precision)) << val;
+    return out.str();
+  }
+
+  /*!
+   * \brief Return the Const Json Strings.
+   *
+   * Builds a JSON object with one entry per constant (in const_names() order)
+   * holding its shape, dtype string, min/max element values and base64 data.
+   * Rank-1 and rank-3 shapes are written with an extra leading 1.
+   *
+   * \return The JSON text.
+   */
+  std::string GetConstJSONString() {
+    std::string json_string;
+    auto names = const_names();
+    auto map = const_name_to_constant();
+    std::vector<const_struct> const_info_vec;
+    for (auto name_const : names) {
+      const_struct a;
+      std::string const_string = name_const;
+      auto arr = map[const_string];
+      a.name = const_string;
+      // NOTE: the dtype string is always "float<bits>", whatever the type code.
+      a.dtype = "float" + std::to_string(static_cast<int>(arr->dtype.bits));
+      std::string shape;
+      shape += "[ ";
+
+      int ndim = arr->ndim;
+      if (ndim == 1 || ndim == 3) {
+        shape += "1, ";
+      }
+      // 64-bit element and byte counts so large constants cannot overflow an int.
+      int64_t tot_dim = 1;
+      for (int i = 0; i < ndim; i++) {
+        tot_dim *= arr->shape[i];
+        shape += std::to_string(arr->shape[i]);
+        if (i != ndim - 1) shape += ", ";
+      }
+      shape += " ]";
+      a.shape = shape;
+      int64_t size = (arr->dtype.bits + 7) / 8;
+      int64_t num_bytes = tot_dim * size;
+      std::string blob;
+      dmlc::MemoryStringStream mstrm(&blob);
+      support::Base64OutStream b64strm(&mstrm);
+      b64strm.Write(arr->data, num_bytes);
+      b64strm.Finish();
+      a.data_base64 = blob;
+
+      // Populate min and max. Elements are read through their real C type.
+      // Reading every dtype through float* misreads non-float32 data and reads
+      // past the end of buffers whose elements are narrower than 4 bytes.
+      float min_val = std::numeric_limits<float>::infinity();
+      float max_val = -min_val;
+      auto scan = [&](const auto* elems) {
+        for (int64_t i = 0; i < tot_dim; i++) {
+          auto val = static_cast<float>(elems[i]);
+          if (val > max_val) max_val = val;
+          if (val < min_val) min_val = val;
+        }
+      };
+      const DLDataType dtype = arr->dtype;
+      if (dtype.code == kDLFloat && dtype.bits == 32) {
+        scan(static_cast<const float*>(arr->data));
+      } else if (dtype.code == kDLFloat && dtype.bits == 64) {
+        scan(static_cast<const double*>(arr->data));
+      } else if (dtype.code == kDLInt && dtype.bits == 8) {
+        scan(static_cast<const int8_t*>(arr->data));
+      } else if (dtype.code == kDLInt && dtype.bits == 16) {
+        scan(static_cast<const int16_t*>(arr->data));
+      } else if (dtype.code == kDLInt && dtype.bits == 32) {
+        scan(static_cast<const int32_t*>(arr->data));
+      } else if (dtype.code == kDLInt && dtype.bits == 64) {
+        scan(static_cast<const int64_t*>(arr->data));
+      } else if (dtype.code == kDLUInt && dtype.bits == 8) {
+        scan(static_cast<const uint8_t*>(arr->data));
+      } else {
+        LOG(WARNING) << "Mrvl: min/max not computed for constant " << const_string
+                     << " (type code " << static_cast<int>(dtype.code) << ", bits "
+                     << static_cast<int>(dtype.bits) << ")";
+      }
+
+      a.min = FloatToString<float>(min_val);
+      a.max = FloatToString<float>(max_val);
+
+      const_info_vec.push_back(a);
+    }
+
+    json_string += "{\n";
+    for (unsigned int i = 0; i < const_info_vec.size(); i++) {
+      auto a = const_info_vec[i];
+      json_string += "\t\"" + a.name + "\": {\n";
+      json_string += "\t\"shape\": " + a.shape + ",\n";
+      json_string += "\t\"dtype\": \"" + a.dtype + "\"" + ",\n";
+      json_string += "\t\"min\": \"" + a.min + "\"" + ",\n";
+      json_string += "\t\"max\": \"" + a.max + "\"" + ",\n";
+      json_string += "\t\"data_base64\": \"" + a.data_base64 + "\"\n";
+      if (i == const_info_vec.size() - 1) {
+        json_string += "\t}\n";
+      } else {
+        json_string += "\t},\n";
+      }
+    }
+    json_string += "}\n";
+    return json_string;
+  }
+
+ protected:
+  /*!
+   * \brief A series of operators that form a composite
+   * convolution: [nn.pad] + nn.conv2d + [add] + [nn.batch_norm] + [activation].
+   * Members stay nullptr for optional parts that are absent.
+   */
+  struct CompositeConvNode {
+    const CallNode* pad = nullptr;
+    const CallNode* conv = nullptr;
+    const CallNode* add = nullptr;
+    const CallNode* batch_norm = nullptr;
+    const CallNode* activation = nullptr;
+  };
+
+  /*!
+   * \brief A series of operators that form a composite
+   * sum2d: add + [activation].
+   */
+  struct CompositeSum2DNode {
+    const CallNode* add = nullptr;
+    const CallNode* activation = nullptr;
+  };
+
+  /*!
+   * \brief A series of operators that form a composite
+   * maxpool or avgpool: an optional nn.pad followed by the pooling call.
+   */
+  struct CompositePoolNode {
+    const CallNode* pad = nullptr;
+    const CallNode* pool = nullptr;
+  };
+
+  /*!
+   * \brief A series of operators that form a composite
+   * concat.
+   */
+  struct CompositeConcatNode {
+    const CallNode* concat = nullptr;
+  };
+
+  /*!
+   * \brief A series of operators that form a reshape node.
+   */
+  struct CompositeReshapeNode {
+    const CallNode* reshape = nullptr;
+  };
+
+  /*!
+   * \brief A series of operators that form a transform reshape node
+   * (presumably a layout_transform followed by reshape / batch_flatten).
+   */
+  struct CompositeLayoutTransformReshapeNode {
+    const CallNode* transform = nullptr;
+    const CallNode* reshape = nullptr;
+  };
+
+  /*!
+   * \brief A series of operators that form a batch flatten node.
+   */
+  struct CompositeBatchFlattenNode {
+    const CallNode* batch_flatten = nullptr;
+  };
+
+  /*!
+   * \brief A series of operators that form a composite
+   * fc layer: nn.dense (stored in `fc`) + [add] + [activation].
+   */
+  struct CompositeFcNode {
+    const CallNode* fc = nullptr;
+    const CallNode* add = nullptr;
+    const CallNode* activation = nullptr;
+  };
+
+  /*!
+   * \brief Visit a call node and emit the matching Mrvl JSON node.
+   *
+   * Primitive op calls are forwarded to the base JSONSerializer, except
+   * layout_transform and transpose, which become Mrvl layers directly.
+   * Calls to composite functions are dispatched on their "Composite" name.
+   * Any other callee, or an unknown composite name, is fatal.
+   *
+   * \param cn The current call node.
+   * \return A list of graph entry nodes.
+   */
+  std::vector<JSONGraphNodeEntry> VisitExpr_(const CallNode* cn) override {
+    if (const auto* op_node = cn->op.as<OpNode>()) {
+      const String op_name = op_node->name;
+      const bool is_mrvl_op = (op_name == "layout_transform") || (op_name == "transpose");
+      if (!is_mrvl_op) {
+        return JSONSerializer::VisitExpr_(cn);
+      }
+      // setup json attributes and then add the Mrvl Layer to JSON files
+      std::shared_ptr<JSONGraphNode> op_layer = CreateMrvlLayer4OpNode(cn);
+      return AddNode(op_layer, GetRef<Expr>(cn));
+    }
+
+    // handle only mrvl composite functions
+    const auto* fn = cn->op.as<FunctionNode>();
+    if (fn == nullptr) {
+      LOG(FATAL) << "Mrvl JSON runtime does not support calls to " << cn->op->GetTypeKey();
+    }
+    const auto composite = fn->GetAttr<String>(attr::kComposite);
+    ICHECK(composite.defined())
+        << "Marvell-Compiler-ERROR-Internal::Illegal Mrvl composite function.";
+    const std::string name = composite.value();
+
+    std::shared_ptr<JSONGraphNode> layer;
+    if (name == "mrvl.conv2d_nhwc2nhwc") {
+      layer = CreateCompositeMrvlConv2DLayer(cn);
+    } else if (name == "mrvl.fc_ni2no") {
+      layer = CreateCompositeMrvlFcLayer(cn);
+    } else if (name == "mrvl.maxpool2d_nhwc2nhwc") {
+      layer = CreateCompositeMrvlMaxpool2DLayer(cn);
+    } else if (name == "mrvl.avgpool2d_nhwc2nhwc") {
+      layer = CreateCompositeMrvlAvgpool2DLayer(cn);
+    } else if (name == "mrvl.globalavgpool2d_nhwc2nhwc") {
+      layer = CreateCompositeMrvlGlobalAvgpool2DLayer(cn);
+    } else if (name == "mrvl.sum2d") {
+      layer = CreateCompositeMrvlSum2DLayer(cn);
+    } else if (name == "mrvl.concat") {
+      layer = CreateMrvlConcatLayer(cn);
+    } else if (name == "mrvl.layout_transform_nhwc2nchw_reshape") {
+      layer = CreateMrvlLayoutTransposeReshapeLayer(cn);
+    } else if (name == "mrvl.reshape") {
+      layer = CreateMrvlReshapeLayer(cn);
+    } else if (name == "mrvl.batch_flatten") {
+      layer = CreateMrvlBatchFlattenLayer(cn);
+    } else {
+      LOG(FATAL) << "Unrecognized Mrvl pattern: " << name;
+    }
+    // calling codegen_json.h::AddNode()
+    return AddNode(layer, GetRef<Expr>(cn));
+  }
+
+ private:
+  /*! \brief The symbol that represents the layer json graph. */
+  std::string layer_name_;
+  /*! \brief Extra param names that GetParams() appends after the base constants. */
+  Array<String> batch_norm_params_;
+  /*! \brief Counter starting at 0; usage is not visible in this part of the file. */
+  int node_idx_{0};
+  /*! \brief Counter starting at 0; presumably used to suffix generated constant names — TODO confirm. */
+  int const_suffix_{0};
+
+  /*!
+   * \brief Extract convolution nodes from a composite function.
+   *
+   * Walks the body from its root back to nn.conv2d, accepting
+   * [nn.relu] -> [TupleGetItem(0) of nn.batch_norm] -> [add] -> nn.conv2d -> [nn.pad].
+   * Optional parts that are absent stay nullptr in the result.
+   *
+   * \param call The call node of the composite function.
+   * \return Extracted composite convolution nodes.
+   */
+  CompositeConvNode UnpackCompositeConvolution(const CallNode* call) {
+    CompositeConvNode nodes{};
+    const auto* fn = call->op.as<FunctionNode>();
+    ICHECK(fn) << "Marvell-Compiler-ERROR-Internal::Downcast to FunctionNode failed.";
+    // - conv2d + [ bias_add ] + [ batch_norm + tuple.getitem(0) ] + [ relu ]
+    // Traverse composite convolution function from child to parent
+    const TupleGetItemNode* tuple_get_item_node = nullptr;
+    const CallNode* current_call = fn->body.as<CallNode>();
+    if (current_call) {
+      if (backend::IsOp(current_call, "nn.relu")) {
+        nodes.activation = current_call;
+        tuple_get_item_node = current_call->args[0].as<TupleGetItemNode>();
+        if (tuple_get_item_node == nullptr) {
+          current_call = current_call->args[0].as<CallNode>();
+        }
+      }
+    } else {
+      tuple_get_item_node = fn->body.as<TupleGetItemNode>();
+    }
+
+    if (tuple_get_item_node != nullptr) {
+      ICHECK(tuple_get_item_node->index == 0)
+          << "Marvell-Compiler-ERROR-Internal::(index == 0) failed for the TupleGetItem node.";
+      current_call = tuple_get_item_node->tuple.as<CallNode>();
+      // Reject a null pointer before it reaches backend::IsOp.
+      ICHECK(current_call) << "Marvell-Compiler-ERROR-Internal::Downcast to CallNode failed.";
+      ICHECK(backend::IsOp(current_call, "nn.batch_norm"))
+          << "Marvell-Compiler-ERROR-Internal::nn.batch_norm Op missing.";
+      nodes.batch_norm = current_call;
+      current_call = nodes.batch_norm->args[0].as<CallNode>();
+    }
+
+    ICHECK(current_call) << "Marvell-Compiler-ERROR-Internal::Downcast to CallNode failed.";
+    if (backend::IsOp(current_call, "add")) {
+      nodes.add = current_call;
+      current_call = current_call->args[0].as<CallNode>();
+      ICHECK(current_call) << "Marvell-Compiler-ERROR-Internal::Downcast to CallNode failed.";
+    }
+
+    ICHECK(backend::IsOp(current_call, "nn.conv2d"))
+        << "Marvell-Compiler-ERROR-Internal::nn.conv2d Op missing.";
+    nodes.conv = current_call;
+    current_call = current_call->args[0].as<CallNode>();
+
+    if (current_call && backend::IsOp(current_call, "nn.pad")) {
+      nodes.pad = current_call;
+    }
+    return nodes;
+  }
+
+  /*!
+   * \brief Collect the nodes of a composite sum2d function: an add with an optional trailing relu.
+   *
+   * \param call Call to the composite function.
+   * \return The add node, plus the relu node when one wraps the add.
+   */
+  CompositeSum2DNode UnpackCompositeSum2D(const CallNode* call) {
+    const auto* fn = call->op.as<FunctionNode>();
+    ICHECK(fn) << "Marvell-Compiler-ERROR-Internal::Downcast to FunctionNode failed.";
+
+    CompositeSum2DNode nodes{};
+    const CallNode* current_call = fn->body.as<CallNode>();
+
+    // Peel off the optional relu sitting on top of the add.
+    const bool has_relu = backend::IsOp(current_call, "nn.relu");
+    if (has_relu) {
+      nodes.activation = current_call;
+      current_call = current_call->args[0].as<CallNode>();
+    }
+
+    ICHECK(backend::IsOp(current_call, "add"))
+        << "Marvell-Compiler-ERROR-Internal::add Op missing.";
+    nodes.add = current_call;
+    return nodes;
+  }
+
+  /*!
+   * \brief Collect the single concatenate call forming a composite concat function.
+   *
+   * \param call Call to the composite function.
+   * \return Nodes with \c concat set to the body's concatenate call.
+   */
+  CompositeConcatNode UnpackCompositeConcat(const CallNode* call) {
+    const auto* fn = call->op.as<FunctionNode>();
+    ICHECK(fn) << "Marvell-Compiler-ERROR-Internal::Downcast to FunctionNode failed.";
+
+    // The whole composite body is one concatenate op.
+    const auto* current_call = fn->body.as<CallNode>();
+    ICHECK(backend::IsOp(current_call, "concatenate"))
+        << "Marvell-Compiler-ERROR-Internal::concatenate Op missing.";
+
+    CompositeConcatNode nodes{};
+    nodes.concat = current_call;
+    return nodes;
+  }
+
+  /*!
+   * \brief Collect the nodes of a composite layout_transform + reshape function.
+   *
+   * The body is a reshape (or nn.batch_flatten) whose first argument is a layout_transform.
+   *
+   * \param call Call to the composite function.
+   * \return Nodes with \c reshape (output side) and \c transform (input side) set.
+   */
+  CompositeLayoutTransformReshapeNode UnpackCompositeLayoutTransposeReshape(const CallNode* call) {
+    const auto* fn = call->op.as<FunctionNode>();
+    ICHECK(fn) << "Marvell-Compiler-ERROR-Internal::Downcast to FunctionNode failed.";
+    CompositeLayoutTransformReshapeNode nodes{};
+
+    // Output side: reshape or batch_flatten.
+    const CallNode* current_call = fn->body.as<CallNode>();
+    ICHECK(backend::IsOp(current_call, "reshape") ||
+           backend::IsOp(current_call, "nn.batch_flatten"))
+        << "Marvell-Compiler-ERROR-Internal::Reshape/Batch_flatten Op missing.";
+    nodes.reshape = current_call;
+
+    // Input side: the layout_transform feeding it.
+    current_call = nodes.reshape->args[0].as<CallNode>();
+    ICHECK(backend::IsOp(current_call, "layout_transform"))
+        << "Marvell-Compiler-ERROR-Internal::Layout_Transform Op missing.";
+    nodes.transform = current_call;
+    return nodes;
+  }
+
+  /*!
+   * \brief Collect the single reshape call forming a composite reshape function.
+   *
+   * \param call Call to the composite function.
+   * \return Nodes with \c reshape set to the body's reshape call.
+   */
+  CompositeReshapeNode UnpackCompositeReshape(const CallNode* call) {
+    const auto* fn = call->op.as<FunctionNode>();
+    ICHECK(fn) << "Marvell-Compiler-ERROR-Internal::Downcast to FunctionNode failed.";
+    const auto* current_call = fn->body.as<CallNode>();
+    ICHECK(backend::IsOp(current_call, "reshape"))
+        << "Marvell-Compiler-ERROR-Internal::reshape missing.";
+    CompositeReshapeNode nodes{};
+    nodes.reshape = current_call;
+    return nodes;
+  }
+
+  /*!
+   * \brief Collect the single nn.batch_flatten call forming a composite batch-flatten function.
+   *
+   * \param call Call to the composite function.
+   * \return Nodes with \c batch_flatten set to the body's call.
+   */
+  CompositeBatchFlattenNode UnpackCompositeBatchFlatten(const CallNode* call) {
+    const auto* fn = call->op.as<FunctionNode>();
+    ICHECK(fn) << "Marvell-Compiler-ERROR-Internal::Downcast to FunctionNode failed.";
+    const auto* current_call = fn->body.as<CallNode>();
+    ICHECK(backend::IsOp(current_call, "nn.batch_flatten"))
+        << "Marvell-Compiler-ERROR-Internal::batch_flatten missing.";
+    CompositeBatchFlattenNode nodes{};
+    nodes.batch_flatten = current_call;
+    return nodes;
+  }
+
+  /*!
+   * \brief Extract pooling nodes (max, avg or global avg) from a composite function.
+   *
+   * \param call The call node of the composite function.
+   * \param mrvlLayerName One of "Maxpool2D", "Avgpool2D" or "GlobalAvgpool2D"; selects the
+   *        pooling op expected at the root of the composite body.
+   * \return Extracted composite pool nodes; \c pad stays nullptr when no nn.pad feeds the pool.
+   */
+  CompositePoolNode UnpackCompositePool(const CallNode* call, const std::string& mrvlLayerName) {
+    CompositePoolNode nodes{};
+    const auto* fn = call->op.as<FunctionNode>();
+    ICHECK(fn) << "Marvell-Compiler-ERROR-Internal::Downcast to FunctionNode failed.";
+
+    // Traverse composite pool function from child to parent.
+    const auto* current_call = fn->body.as<CallNode>();
+    // backend::IsOp dereferences its argument, so reject a non-call body up front.
+    ICHECK(current_call) << "Marvell-Compiler-ERROR-Internal::Downcast to CallNode failed.";
+
+    if (mrvlLayerName == "Maxpool2D") {
+      ICHECK(backend::IsOp(current_call, "nn.max_pool2d"))
+          << "Marvell-Compiler-ERROR-Internal::nn.max_pool2d Op missing.";
+    } else if (mrvlLayerName == "Avgpool2D") {
+      ICHECK(backend::IsOp(current_call, "nn.avg_pool2d"))
+          << "Marvell-Compiler-ERROR-Internal::nn.avg_pool2d Op missing.";
+    } else {
+      ICHECK(mrvlLayerName == "GlobalAvgpool2D")
+          << "Marvell-Compiler-ERROR-Internal::unsupported pool layer name: " << mrvlLayerName;
+      ICHECK(backend::IsOp(current_call, "nn.global_avg_pool2d"))
+          << "Marvell-Compiler-ERROR-Internal::nn.global_avg_pool2d Op missing.";
+    }
+    nodes.pool = current_call;
+
+    // An optional nn.pad may feed the pooling op's data input.
+    const CallNode* producer = current_call->args[0].as<CallNode>();
+    if (producer && backend::IsOp(producer, "nn.pad")) {
+      nodes.pad = producer;
+    }
+    return nodes;
+  }
+
+  /*!
+   * \brief Collect the nodes of a composite fully-connected function.
+   *
+   * Expected chain (input to output): nn.dense -> [add] -> [nn.relu] -> [nn.batch_flatten].
+   *
+   * \param call Call to the composite function.
+   * \return Nodes with \c fc set, and \c add / \c activation when present.
+   */
+  CompositeFcNode UnpackCompositeFc(const CallNode* call) {
+    const auto* fn = call->op.as<FunctionNode>();
+    ICHECK(fn) << "Marvell-Compiler-ERROR-Internal::Downcast to FunctionNode failed.";
+
+    CompositeFcNode nodes{};
+    const auto* current_call = fn->body.as<CallNode>();
+    // Step from a call to the call producing its first argument.
+    auto producer_of = [](const CallNode* c) { return c->args[0].as<CallNode>(); };
+
+    // Walk from the output back towards nn.dense, skipping/recording optional stages.
+    if (backend::IsOp(current_call, "nn.batch_flatten")) {
+      current_call = producer_of(current_call);
+    }
+    if (backend::IsOp(current_call, "nn.relu")) {
+      nodes.activation = current_call;
+      current_call = producer_of(current_call);
+    }
+    if (backend::IsOp(current_call, "add")) {
+      nodes.add = current_call;
+      current_call = producer_of(current_call);
+    }
+    ICHECK(backend::IsOp(current_call, "nn.dense"))
+        << "Marvell-Compiler-ERROR-Internal::nn.dense Op missing.";
+    nodes.fc = current_call;
+    return nodes;
+  }
+
+  /*!
+   * \brief Attach a list-of-strings attribute to a JSON graph node.
+   *
+   * \param json_node Node to annotate.
+   * \param key Attribute name.
+   * \param string_vec Attribute value, stored as a single dmlc::any entry.
+   */
+  void JsonNodeSetAttr(std::shared_ptr<JSONGraphNode> json_node, const std::string& key,
+                       const std::vector<std::string>& string_vec) {
+    std::vector<dmlc::any> attr_values;
+    attr_values.reserve(1);
+    attr_values.emplace_back(string_vec);
+    json_node->SetAttr(key, attr_values);
+  }
+
+  /*!
+   * \brief Attach an integer-vector attribute (e.g. a tensor shape) to a JSON graph node.
+   *
+   * Every element is converted to its decimal string form and stored, in order, as one
+   * list-of-strings attribute. Works for any length, including empty (rank-0) vectors.
+   *
+   * \param json_node Node to annotate.
+   * \param key Attribute name.
+   * \param tvec Values to store.
+   */
+  void JsonNodeSetVecAttr(std::shared_ptr<JSONGraphNode> json_node, const std::string& key,
+                          const std::vector<int64_t>& tvec) {
+    // Emit every element: no special-casing by size, so ranks above 4 are not truncated
+    // and an empty vector is not indexed out of bounds.
+    std::vector<std::string> tvec_str;
+    tvec_str.reserve(tvec.size());
+    for (int64_t value : tvec) {
+      tvec_str.push_back(std::to_string(value));
+    }
+    std::vector<dmlc::any> json_attr;
+    json_attr.emplace_back(tvec_str);
+    json_node->SetAttr(key, json_attr);
+  }
+
+  /*!
+   * \brief Record the fused batch-norm parameters of a layer on its JSON node.
+   *
+   * No-op when \p cn_batchnorm is nullptr. Otherwise copies the batch_norm call's attributes
+   * and allocates four fresh constant names, consumed in the order gamma, beta, mean, var,
+   * each with layout "-O".
+   *
+   * \param json_node Node to annotate.
+   * \param cn_batchnorm The nn.batch_norm call, or nullptr.
+   */
+  void SetMrvlLayerBatchnormAttrs(std::shared_ptr<JSONGraphNode> json_node,
+                                  const CallNode* cn_batchnorm) {
+    if (cn_batchnorm == nullptr) {
+      return;
+    }
+    SetCallNodeAttribute(json_node, cn_batchnorm);
+
+    const std::string batch_norm_layout = "-O";
+    // Constant names first (this order fixes which const_suffix_ each parameter receives).
+    for (const char* param : {"gamma", "beta", "mean", "var"}) {
+      const std::string const_name = layer_name_ + "_const_" + std::to_string(const_suffix_++);
+      JsonNodeSetAttr(json_node, std::string(param) + "_const_name", {const_name});
+    }
+    // Then the shared layout of every parameter.
+    for (const char* param : {"gamma", "beta", "mean", "var"}) {
+      JsonNodeSetAttr(json_node, std::string(param) + "_layout", {batch_norm_layout});
+    }
+  }
+
+ /*!
+ * \brief Record the spatial padding of an optional nn.pad producer on a JSON node.
+ *
+ * Does nothing when cn_pad is nullptr. Otherwise requires PadAttrs with pad_mode
+ * "constant" and stores a "padding" attribute of four strings taken from pad_width
+ * rows 1 and 2 in the order (p[1][0], p[2][0], p[1][1], p[2][1]).
+ * NOTE(review): rows 1 and 2 are the H and W axes only if pad_width is NHWC-ordered;
+ * presumably guaranteed by the callers' NHWC layout checks — confirm.
+ *
+ * \param json_node Node to annotate.
+ * \param cn_pad The nn.pad call, or nullptr when the composite has no padding.
+ */
+ void SetMrvlLayerPadAttrs(std::shared_ptr<JSONGraphNode> json_node, const CallNode* cn_pad) {
+ if (cn_pad == nullptr) return;
+
+ const auto* pad_attr = cn_pad->attrs.as<PadAttrs>();
+ ICHECK(pad_attr) << "Marvell-Compiler-ERROR-Internal::Downcast to PadAttrs failed.";
+ // NOTE(review): this asserts only that the pad-value argument is NOT a ConstantNode
+ // (the downcast pointer is null); the value itself is never inspected, so the
+ // "non-zero" message does not match what is checked. Confirm the intent, e.g. verify
+ // that a constant pad value equals 0.
+ ICHECK(cn_pad->args[1].as<ConstantNode>() == 0)
+ << "Marvell-Compiler-ERROR-Internal::padded value is non-zero.";
+ ICHECK(pad_attr->pad_mode == "constant")
+ << "Marvell-Compiler-ERROR-Internal::unsupported padding mode.";
+
+ auto p = pad_attr->pad_width;
+ // Convert to TVM layout for now, conversion to Mrvl layout takes place in runtime.
+ // Standard pad layout for TVM: top, left, bottom, right.
+ std::vector<std::string> padding = {std::to_string(p[1][0].as<IntImmNode>()->value),
+ std::to_string(p[2][0].as<IntImmNode>()->value),
+ std::to_string(p[1][1].as<IntImmNode>()->value),
+ std::to_string(p[2][1].as<IntImmNode>()->value)};
+
+ JsonNodeSetAttr(json_node, "padding", {padding});
+ }
+
+ /*!
+ * \brief Set the attributes shared by all Mrvl layer JSON nodes.
+ *
+ * Writes layer_name, func_node_name, input shapes ("data_layout_shape" plus
+ * "data_layout_shape_<i>" when there are several non-constant inputs), "from_tuple_idx",
+ * the optional data_layout / kernel_layout / out_layout strings, output shapes
+ * ("out_layout_shape" plus "out_layout_shape_<i>" for multi-output calls), a unique
+ * func_name "n<idx>_<mrvlLayerName>[_<data_layout>[2<out_layout>]]", and composite_name
+ * when cn calls a composite function.
+ *
+ * \param json_node Node to annotate.
+ * \param cn Call being serialized.
+ * \param func_name Value stored as func_node_name (the callers shown here pass layer_name_).
+ * \param mrvlLayerName Layer type name, stored as layer_name and used in func_name.
+ * \param data_layout Data layout string; not stored when empty.
+ * \param kernel_layout Kernel layout string; not stored when empty.
+ * \param out_layout Output layout string; not stored when empty.
+ */
+ void SetMrvlLayerCommonAttrs(std::shared_ptr<JSONGraphNode> json_node, const CallNode* cn,
+ const std::string& func_name, const std::string& mrvlLayerName,
+ const std::string& data_layout, const std::string& kernel_layout,
+ const std::string& out_layout) {
+ JsonNodeSetAttr(json_node, "layer_name", {mrvlLayerName});
+ JsonNodeSetAttr(json_node, "func_node_name", {func_name});
+ std::vector<int64_t> data_layout_vec;
+
+ auto num_inputs = GetInputNum(cn);
+ auto num_outputs = GetOutputNum(cn);
+ // Constant arguments are not counted as data inputs.
+ // NOTE(review): this indexes cn->args up to GetInputNum(cn), which can exceed
+ // cn->args.size() when args[0] is a tuple-typed Var; and the loop below reads
+ // args[0..num_inputs), which assumes constants come after all data inputs — confirm.
+ auto counter = num_inputs;
+ for (size_t i = 0; i < counter; i++) {
+ if (cn->args[i].as<ConstantNode>()) num_inputs--;
+ }
+
+ // tuple_idx_vec collects, per input, the TupleGetItem index it came from (-1 if none).
+ std::vector<int64_t> tuple_idx_vec;
+ int tuple_idx = -1;
+ if (num_inputs > 1) {
+ for (size_t in_idx = 0; in_idx < num_inputs; in_idx++) {
+ std::vector<int64_t> data_layout_vec_n;
+ GetInputTensorShapeViaArg(cn, &data_layout_vec_n, &tuple_idx, in_idx);
+ std::string attr_name = "data_layout_shape_" + std::to_string(in_idx);
+ JsonNodeSetVecAttr(json_node, attr_name, data_layout_vec_n);
+ tuple_idx_vec.push_back(tuple_idx);
+ // The first input's shape is also published under the unsuffixed key.
+ if (in_idx == 0) {
+ JsonNodeSetVecAttr(json_node, "data_layout_shape", data_layout_vec_n);
+ }
+ }
+ } else {
+ GetInputTensorShapeViaArg(cn, &data_layout_vec, &tuple_idx, 0);
+ JsonNodeSetVecAttr(json_node, "data_layout_shape", data_layout_vec);
+ tuple_idx_vec.push_back(tuple_idx);
+ }
+ JsonNodeSetVecAttr(json_node, "from_tuple_idx", tuple_idx_vec);
+
+ if (data_layout != "") {
+ std::vector<std::string> data_layout_format_vec = {data_layout};
+ JsonNodeSetAttr(json_node, "data_layout", data_layout_format_vec);
+ }
+
+ std::vector<int64_t> out_layout_vec;
+ if (num_outputs > 1) {
+ std::vector<std::vector<int64_t>> output_layout_vec_vec;
+ GetOutputTensorShapes(cn, &output_layout_vec_vec);
+ for (size_t out_idx = 0; out_idx < num_outputs; out_idx++) {
+ std::string attr_name = "out_layout_shape_" + std::to_string(out_idx);
+ JsonNodeSetVecAttr(json_node, attr_name, output_layout_vec_vec.at(out_idx));
+ }
+ // For compatibility with the backend, also publish the first output's shape unsuffixed.
+ JsonNodeSetVecAttr(json_node, "out_layout_shape", output_layout_vec_vec.at(0));
+ } else {
+ GetOutputTensorShape(cn, &out_layout_vec);
+ JsonNodeSetVecAttr(json_node, "out_layout_shape", out_layout_vec);
+ }
+
+ if (kernel_layout != "") {
+ std::vector<std::string> kernel_layout_format_vec = {kernel_layout};
+ JsonNodeSetAttr(json_node, "kernel_layout", kernel_layout_format_vec);
+ }
+ if (out_layout != "") {
+ std::vector<std::string> out_layout_format_vec = {out_layout};
+ JsonNodeSetAttr(json_node, "out_layout", out_layout_format_vec);
+ }
+
+ // Set up n<#>_<mrvlLayerName> as the GUI node name ("func_name") in the nodes JSON file;
+ // node_idx_ is incremented so each node gets a unique index.
+ std::string node_id_func_name = "";
+ node_id_func_name = "n" + std::to_string(node_idx_++) + "_" + mrvlLayerName;
+
+ // Add postfix layout(s) when both are known: "_<data>" and, if different, "2<out>".
+ if ((data_layout != "") && (out_layout != "")) {
+ node_id_func_name += "_" + data_layout;
+ if (data_layout != out_layout) {
+ node_id_func_name += "2" + out_layout;
+ }
+ }
+
+ JsonNodeSetAttr(json_node, "func_name", {node_id_func_name});
+
+ // Only calls to composite functions carry a composite_name.
+ const auto* fn = cn->op.as<FunctionNode>();
+ if (fn != nullptr) {
+ ICHECK(fn->IsInstance<FunctionNode>())
+ << "Marvell-Compiler-ERROR-Internal::Downcast to FunctionNode failed.";
+ auto composite = fn->GetAttr<String>(attr::kComposite);
+ ICHECK(composite.defined())
+ << "Marvell-Compiler-ERROR-Internal::Illegal Mrvl composite function.";
+ std::string composite_name = composite.value();
+ JsonNodeSetAttr(json_node, "composite_name", {composite_name});
+ }
+ }
+
+  /*!
+   * \brief Append the static shape of a call's first argument, selecting a tuple field if needed.
+   *
+   * When args[0] is a Var whose type is a tuple, field \p index of that tuple is used;
+   * for a Call argument (or a tensor-typed Var) \p index is ignored.
+   *
+   * \param call_node_ptr Call whose first argument is inspected (must have at least one arg).
+   * \param index Tuple field to use when args[0] is a tuple-typed Var.
+   * \param tensor_shape Output; dimensions are appended.
+   */
+  void GetInputTensorShapeFromTuple(const CallNode* call_node_ptr, size_t index,
+                                    std::vector<int64_t>* tensor_shape) {
+    ICHECK(!call_node_ptr->args.empty());
+    const TensorTypeNode* tensor_type = nullptr;
+    if (const auto* arg0_call = call_node_ptr->args[0].as<CallNode>()) {
+      tensor_type = arg0_call->checked_type_.as<TensorTypeNode>();
+    } else if (const auto* arg0_var = call_node_ptr->args[0].as<VarNode>()) {
+      tensor_type = arg0_var->checked_type_.as<TensorTypeNode>();
+      if (const auto* tuple_type = arg0_var->checked_type_.as<TupleTypeNode>()) {
+        tensor_type = tuple_type->fields[index].as<TensorTypeNode>();
+      }
+    } else {
+      LOG(INFO) << "TVM Mrvl runtime does not support calls to "
+                << call_node_ptr->args[0]->GetTypeKey();
+    }
+
+    ICHECK((tensor_type != nullptr) && tensor_type->IsInstance<TensorTypeNode>())
+        << "Marvell-Compiler-ERROR-Internal::Downcast to TensorTypeNode failed.";
+    for (IndexExpr dim_val : tensor_type->shape) {
+      // Symbolic / dynamic dimensions have no constant value; reject them explicitly
+      // instead of dereferencing a null pointer.
+      const int64_t* dim = tir::as_const_int(dim_val);
+      ICHECK(dim != nullptr) << "Marvell-Compiler-ERROR-Internal::"
+                             << "dynamic (non-constant) tensor dimension is not supported.";
+      tensor_shape->push_back(*dim);
+    }
+  }
+
+  /*!
+   * \brief Count the inputs of a call.
+   *
+   * Defaults to the number of call arguments; if the first argument is a Var whose type is a
+   * tuple (not a tensor), the number of tuple fields is returned instead.
+   *
+   * \param call_node_ptr Call to inspect (must have at least one argument).
+   * \return Number of inputs.
+   */
+  size_t GetInputNum(const CallNode* call_node_ptr) {
+    size_t num_inputs = call_node_ptr->args.size();
+    ICHECK(!call_node_ptr->args.empty());
+    const Expr first_arg = call_node_ptr->args[0];
+
+    if (const auto* getitem = first_arg.as<TupleGetItemNode>()) {
+      // A projected tuple element keeps the default count; its type is still queried.
+      static_cast<void>(getitem->checked_type().as<TensorTypeNode>());
+    } else if (first_arg.as<CallNode>() != nullptr) {
+      // Plain call argument: one input per call argument.
+      num_inputs = call_node_ptr->args.size();
+    } else if (first_arg.as<VarNode>() != nullptr) {
+      const auto* arg_0 = first_arg.as<VarNode>();
+      ICHECK((arg_0 != nullptr) && arg_0->IsInstance<VarNode>())
+          << "Marvell-Compiler-ERROR-Internal::Downcast to VarNode failed.";
+      // A tuple-typed variable contributes one input per tuple field.
+      const bool is_tensor = arg_0->checked_type_.as<TensorTypeNode>() != nullptr;
+      const TupleTypeNode* tuple_type =
+          is_tensor ? nullptr : arg_0->checked_type_.as<TupleTypeNode>();
+      if (tuple_type != nullptr) {
+        num_inputs = tuple_type->fields.size();
+      }
+    } else {
+      LOG(INFO) << "TVM Mrvl runtime does not support calls to "
+                << call_node_ptr->args[0]->GetTypeKey();
+    }
+    return num_inputs;
+  }
+
+  /*!
+   * \brief Count the outputs of a call: the number of fields for a tuple-typed call, else 1.
+   *
+   * \param call_node_ptr Call to inspect (must be non-null).
+   * \return Number of outputs.
+   */
+  size_t GetOutputNum(const CallNode* call_node_ptr) {
+    ICHECK(call_node_ptr != nullptr);
+    const auto* tuple_type = call_node_ptr->checked_type_.as<TupleTypeNode>();
+    return (tuple_type != nullptr) ? static_cast<size_t>(tuple_type->fields.size())
+                                   : static_cast<size_t>(1);
+  }
+
+  /*!
+   * \brief Append the static shape of argument \p n of a call.
+   *
+   * \param call_node_ptr Call whose argument is inspected.
+   * \param tensor_shape Output; dimensions are appended.
+   * \param tuple_index Output; the TupleGetItem index when argument \p n is a TupleGetItem,
+   *        otherwise -1.
+   * \param n Index of the argument to inspect; must be < call_node_ptr->args.size().
+   */
+  void GetInputTensorShapeViaArg(const CallNode* call_node_ptr, std::vector<int64_t>* tensor_shape,
+                                 int* tuple_index, size_t n) {
+    *tuple_index = -1;
+    ICHECK(!call_node_ptr->args.empty());
+    ICHECK_LT(n, call_node_ptr->args.size())
+        << "Marvell-Compiler-ERROR-Internal::argument index out of range.";
+
+    const Expr arg_expr = call_node_ptr->args[n];
+    const TensorTypeNode* tensor_type = nullptr;
+    if (const auto* tuple_get_item_node = arg_expr.as<TupleGetItemNode>()) {
+      *tuple_index = tuple_get_item_node->index;
+      tensor_type = tuple_get_item_node->checked_type().as<TensorTypeNode>();
+    } else if (const auto* arg_call = arg_expr.as<CallNode>()) {
+      tensor_type = arg_call->checked_type().as<TensorTypeNode>();
+    } else if (const auto* arg_var = arg_expr.as<VarNode>()) {
+      tensor_type = arg_var->checked_type().as<TensorTypeNode>();
+      if (tensor_type == nullptr) {
+        // NOTE(review): for a tuple-typed Var the field is selected with the argument
+        // index n — confirm this is the intended field.
+        if (const auto* tuple_type = arg_var->checked_type().as<TupleTypeNode>()) {
+          tensor_type = tuple_type->fields[n].as<TensorTypeNode>();
+        }
+      }
+    } else {
+      LOG(INFO) << "TVM Mrvl runtime does not support calls to "
+                << call_node_ptr->args[n]->GetTypeKey();
+    }
+
+    ICHECK((tensor_type != nullptr) && tensor_type->IsInstance<TensorTypeNode>())
+        << "Marvell-Compiler-ERROR-Internal::Downcast to TensorTypeNode failed.";
+    // use only data types supported by json.h (e.g., int or int64_t or size_t)
+    for (IndexExpr dim_val : tensor_type->shape) {
+      // Symbolic / dynamic dimensions have no constant value; reject them explicitly.
+      const int64_t* dim = tir::as_const_int(dim_val);
+      ICHECK(dim != nullptr) << "Marvell-Compiler-ERROR-Internal::"
+                             << "dynamic (non-constant) tensor dimension is not supported.";
+      tensor_shape->push_back(*dim);
+    }
+  }
+
+  /*!
+   * \brief Append the static shape of a call's first argument; the tuple index is discarded.
+   *
+   * \param call_node_ptr Call whose argument 0 is inspected.
+   * \param tensor_shape Output; dimensions are appended.
+   */
+  void GetInputTensorShapeViaArg0(const CallNode* call_node_ptr,
+                                  std::vector<int64_t>* tensor_shape) {
+    int ignored_tuple_index = -1;
+    GetInputTensorShapeViaArg(call_node_ptr, tensor_shape, &ignored_tuple_index, /*n=*/0);
+  }
+
+  /*!
+   * \brief Append the static shape of a tensor-typed variable.
+   *
+   * \param var_node_ptr Variable whose checked type must be a TensorType.
+   * \param tensor_shape Output; dimensions are appended.
+   */
+  void GetTensorShape(const VarNode* var_node_ptr, std::vector<int64_t>* tensor_shape) {
+    ICHECK((var_node_ptr != nullptr) && var_node_ptr->IsInstance<VarNode>())
+        << "Marvell-Compiler-ERROR-Internal::Downcast to VarNode failed.";
+    const TensorTypeNode* tensor_type = var_node_ptr->checked_type_.as<TensorTypeNode>();
+    ICHECK((tensor_type != nullptr) && tensor_type->IsInstance<TensorTypeNode>())
+        << "Marvell-Compiler-ERROR-Internal::Downcast to TensorTypeNode failed.";
+    // use only data types supported by json.h (e.g., int or int64_t or size_t)
+    for (IndexExpr dim_val : tensor_type->shape) {
+      // Symbolic / dynamic dimensions have no constant value; reject them explicitly
+      // instead of dereferencing a null pointer.
+      const int64_t* dim = tir::as_const_int(dim_val);
+      ICHECK(dim != nullptr) << "Marvell-Compiler-ERROR-Internal::"
+                             << "dynamic (non-constant) tensor dimension is not supported.";
+      tensor_shape->push_back(*dim);
+    }
+  }
+
+  /*!
+   * \brief Append the static shape of a tensor-typed call result.
+   *
+   * \param call_node_ptr Call whose checked type must be a TensorType.
+   * \param tensor_shape Output; dimensions are appended.
+   */
+  void GetOutputTensorShape(const CallNode* call_node_ptr, std::vector<int64_t>* tensor_shape) {
+    ICHECK(call_node_ptr != nullptr);
+    const TensorTypeNode* tensor_type = call_node_ptr->checked_type_.as<TensorTypeNode>();
+    ICHECK((tensor_type != nullptr) && tensor_type->IsInstance<TensorTypeNode>())
+        << "Marvell-Compiler-ERROR-Internal::Downcast to TensorTypeNode failed.";
+    for (IndexExpr dim_val : tensor_type->shape) {
+      // Symbolic / dynamic dimensions have no constant value; reject them explicitly.
+      const int64_t* dim = tir::as_const_int(dim_val);
+      ICHECK(dim != nullptr) << "Marvell-Compiler-ERROR-Internal::"
+                             << "dynamic (non-constant) tensor dimension is not supported.";
+      tensor_shape->push_back(*dim);
+    }
+  }
+
+  /*!
+   * \brief Append the static shape of every field of a tuple-typed call result.
+   *
+   * \param call_node_ptr Call whose checked type must be a TupleType of TensorTypes.
+   * \param tensor_shapes Output; one shape per tuple field is appended, in field order.
+   */
+  void GetOutputTensorShapes(const CallNode* call_node_ptr,
+                             std::vector<std::vector<int64_t>>* tensor_shapes) {
+    ICHECK(call_node_ptr != nullptr);
+
+    const TupleTypeNode* tuple_type = call_node_ptr->checked_type_.as<TupleTypeNode>();
+    ICHECK((tuple_type != nullptr) && tuple_type->IsInstance<TupleTypeNode>())
+        << "Marvell-Compiler-ERROR-Internal::Downcast to TupleTypeNode failed.";
+    for (auto field : tuple_type->fields) {
+      const TensorTypeNode* tensor_type = field.as<TensorTypeNode>();
+      ICHECK((tensor_type != nullptr) && tensor_type->IsInstance<TensorTypeNode>())
+          << "Marvell-Compiler-ERROR-Internal::Downcast to TensorTypeNode failed.";
+      // use only data types supported by json.h (e.g., int or int64_t or size_t)
+      std::vector<int64_t> tensor_shape;
+      for (IndexExpr dim_val : tensor_type->shape) {
+        // Symbolic / dynamic dimensions have no constant value; reject them explicitly.
+        const int64_t* dim = tir::as_const_int(dim_val);
+        ICHECK(dim != nullptr) << "Marvell-Compiler-ERROR-Internal::"
+                               << "dynamic (non-constant) tensor dimension is not supported.";
+        tensor_shape.push_back(*dim);
+      }
+      tensor_shapes->push_back(tensor_shape);
+    }
+  }
+
+  /*!
+   * \brief Create a JSON representation of a composite convolution.
+   *
+   * JSON inputs are, in order: data, weight, [bias], [gamma, beta, mean, var].
+   *
+   * \param cn The call to be represented.
+   * \return A JSON representation of a specific operator.
+   */
+  std::shared_ptr<JSONGraphNode> CreateCompositeMrvlConv2DLayer(const CallNode* cn) {
+    CompositeConvNode nodes = UnpackCompositeConvolution(cn);
+    const auto* conv_attrs = nodes.conv->attrs.as<Conv2DAttrs>();
+    ICHECK(conv_attrs) << "Marvell-Compiler-ERROR-Internal::Downcast to Conv2DAttrs failed.";
+
+    std::string name;
+    const std::string mrvlLayerName = "Conv2D";
+    std::vector<JSONGraphNodeEntry> inputs;
+
+    // data input tensor
+    inputs.push_back(VisitExpr(cn->args[0])[0]);
+    // weight tensor
+    inputs.push_back(VisitExpr(nodes.conv->args[1])[0]);
+    if (nodes.add) {
+      // bias tensor
+      inputs.push_back(VisitExpr(nodes.add->args[1])[0]);
+    }
+    if (nodes.batch_norm) {
+      // gamma, beta, mean and var of batch-norm are batch_norm args 1..4
+      for (size_t arg_idx = 1; arg_idx <= 4; arg_idx++) {
+        ICHECK(nodes.batch_norm->args[arg_idx].as<ConstantNode>())
+            << "Marvell-Compiler-ERROR-Internal::Downcast to ConstantNode failed.";
+        auto n = nodes.batch_norm->args[arg_idx];
+        // Drop any memoized result so the constant is visited again for this layer
+        // (erase is a no-op when no entry exists).
+        memo_.erase(n);
+        inputs.push_back(VisitExpr(n)[0]);
+      }
+    }
+
+    // Distinguish between normal and depth-wise convolution
+    const std::string data_layout = conv_attrs->data_layout;
+    const std::string kernel_layout = conv_attrs->kernel_layout;
+    const std::string out_layout = conv_attrs->out_layout;
+    int groups = conv_attrs->groups;
+    if ((groups != 1) && conv_attrs->channels.defined() &&
+        tvm::tir::ExprDeepEqual()(conv_attrs->channels, conv_attrs->groups)) {
+      name = "nn.dw_conv2d_nhwc2nhwc";
+      // NOTE(review): layouts are not validated for depth-wise convolution. Confirm the
+      // expected depth-wise kernel layout (possibly IHWO) and add an explicit check.
+    } else {
+      name = "nn.conv2d_nhwc2nhwc";
+      ICHECK(data_layout == "NHWC")
+          << "Marvell-Compiler-ERROR-Internal::"
+          << "Data layout must be NHWC, has the module been pre-processed correctly?";
+      ICHECK(kernel_layout == "OHWI")
+          << "Marvell-Compiler-ERROR-Internal::"
+          << "Kernel layout must be OHWI, has the module been pre-processed correctly?";
+      ICHECK(out_layout == "NHWC")
+          << "Marvell-Compiler-ERROR-Internal::"
+          << "Out layout must be NHWC, has the module been pre-processed correctly?";
+    }
+
+    // add json node attributes
+    auto json_node = std::make_shared<JSONGraphNode>(name, "kernel", inputs, 1);
+    SetCallNodeAttribute(json_node, nodes.conv);
+    std::vector<std::string> kernel_const_name = {layer_name_ + "_const_" +
+                                                  std::to_string(const_suffix_++)};
+    JsonNodeSetAttr(json_node, "kernel_const_name", kernel_const_name);
+
+    if (nodes.add) {
+      SetCallNodeAttribute(json_node, nodes.add);
+      std::vector<std::string> bias_const_name = {layer_name_ + "_const_" +
+                                                  std::to_string(const_suffix_++)};
+      JsonNodeSetAttr(json_node, "bias_const_name", bias_const_name);
+      JsonNodeSetAttr(json_node, "bias_layout", {"---O"});
+    }
+    if (nodes.pad) SetMrvlLayerPadAttrs(json_node, nodes.pad);
+    if (nodes.batch_norm) SetMrvlLayerBatchnormAttrs(json_node, nodes.batch_norm);
+    if (nodes.activation) JsonNodeSetAttr(json_node, "activation_type", {"relu"});
+    SetMrvlLayerCommonAttrs(json_node, cn, layer_name_, mrvlLayerName, data_layout, "", out_layout);
+    return json_node;
+  }
+
+  /*!
+   * \brief Build the JSON node for a composite sum2d (element-wise add, optional relu).
+   *
+   * Both add operands become JSON inputs. Layouts follow the first operand's rank:
+   * 4-D gives NHWC, 2-D gives NC, any other rank leaves them empty.
+   *
+   * \param cn Call to the composite function.
+   * \return The JSON node for the layer.
+   */
+  std::shared_ptr<JSONGraphNode> CreateCompositeMrvlSum2DLayer(const CallNode* cn) {
+    CompositeSum2DNode nodes = UnpackCompositeSum2D(cn);
+    ICHECK(nodes.add != nullptr)
+        << "Marvell-Compiler-ERROR-Internal::attribute add can't be nullptr";
+
+    const std::string mrvlLayerName = "Sum2D";
+    const std::string name = "sum2d";
+
+    std::vector<JSONGraphNodeEntry> inputs;
+    for (size_t operand = 0; operand < 2; ++operand) {
+      inputs.push_back(VisitExpr(cn->args[operand])[0]);
+    }
+
+    std::vector<int64_t> layout_vec;
+    GetInputTensorShapeViaArg0(cn, &layout_vec);
+    std::string data_layout;
+    switch (layout_vec.size()) {
+      case 4:
+        data_layout = "NHWC";
+        break;
+      case 2:
+        data_layout = "NC";
+        break;
+      default:
+        break;
+    }
+    const std::string out_layout = data_layout;
+
+    auto json_node = std::make_shared<JSONGraphNode>(name, "kernel", inputs, 1);
+    SetCallNodeAttribute(json_node, nodes.add);
+    if (nodes.activation != nullptr) {
+      JsonNodeSetAttr(json_node, "activation_type", {"relu"});
+    }
+    SetMrvlLayerCommonAttrs(json_node, cn, layer_name_, mrvlLayerName, data_layout, "", out_layout);
+    return json_node;
+  }
+
+  /*!
+   * \brief Build the JSON node for a composite reshape.
+   *
+   * Input and output ranks must each be 2 (layout NC) or 4 (layout NHWC).
+   *
+   * \param cn Call to the composite function.
+   * \return The JSON node for the layer.
+   */
+  std::shared_ptr<JSONGraphNode> CreateMrvlReshapeLayer(const CallNode* cn) {
+    CompositeReshapeNode nodes = UnpackCompositeReshape(cn);
+    const std::string name = "reshape";
+
+    std::vector<JSONGraphNodeEntry> inputs;
+    inputs.push_back(VisitExpr(cn->args[0])[0]);
+
+    // Input rank selects data_layout.
+    std::vector<int64_t> layout_vec;
+    GetInputTensorShapeViaArg0(nodes.reshape, &layout_vec);
+    ICHECK(layout_vec.size() == 2 || layout_vec.size() == 4)
+        << "Marvell-Compiler-ERROR-Internal::"
+        << "Reshape with input tensor dim != 2 or != 4 is not supported yet.";
+    const std::string data_layout = (layout_vec.size() == 4) ? "NHWC" : "NC";
+
+    // Output rank selects out_layout.
+    layout_vec.clear();
+    GetOutputTensorShape(cn, &layout_vec);
+    ICHECK(layout_vec.size() == 2 || layout_vec.size() == 4)
+        << "Marvell-Compiler-ERROR-Internal::"
+        << "Reshape with output tensor dim != 2 or !=4 is not supported yet.";
+    const std::string out_layout = (layout_vec.size() == 4) ? "NHWC" : "NC";
+
+    auto json_node = std::make_shared<JSONGraphNode>(name, "kernel", inputs, 1);
+    SetMrvlLayerCommonAttrs(json_node, cn, layer_name_, name, data_layout,
+                            "" /* no kernel_layout */, out_layout);
+    return json_node;
+  }
+
+  /*!
+   * \brief Build the JSON node for a composite nn.batch_flatten.
+   *
+   * The input must be 2-D (NC) or 4-D (NHWC); the output must be 2-D (NC).
+   *
+   * \param cn Call to the composite function.
+   * \return The JSON node for the layer.
+   */
+  std::shared_ptr<JSONGraphNode> CreateMrvlBatchFlattenLayer(const CallNode* cn) {
+    CompositeBatchFlattenNode nodes = UnpackCompositeBatchFlatten(cn);
+    const std::string name = "nn.batch_flatten";
+    const std::string out_layout = "NC";
+
+    std::vector<JSONGraphNodeEntry> inputs;
+    inputs.push_back(VisitExpr(cn->args[0])[0]);
+
+    std::vector<int64_t> layout_vec;
+    GetInputTensorShapeViaArg0(nodes.batch_flatten, &layout_vec);
+    ICHECK(layout_vec.size() == 2 || layout_vec.size() == 4)
+        << "Marvell-Compiler-ERROR-Internal::"
+        << "nn.batch_flatten with input tensor dim != 2 or != 4 is not supported yet.";
+    const std::string data_layout = (layout_vec.size() == 4) ? "NHWC" : "NC";
+
+    layout_vec.clear();
+    GetOutputTensorShape(cn, &layout_vec);
+    ICHECK(layout_vec.size() == 2)
+        << "Marvell-Compiler-ERROR-Internal::"
+        << "nn.batch_flatten with output tensor dim != 2 is not supported yet.";
+
+    auto json_node = std::make_shared<JSONGraphNode>(name, "kernel", inputs, 1);
+    SetMrvlLayerCommonAttrs(json_node, cn, layer_name_, name, data_layout,
+                            "" /* no kernel_layout */, out_layout);
+    return json_node;
+  }
+
+  /*!
+   * \brief Build the JSON node for a composite concat.
+   *
+   * Every argument of the composite call becomes a JSON input. Layouts follow the first
+   * operand's rank: 4-D gives NHWC, 2-D gives NC, any other rank leaves them empty.
+   *
+   * \param cn Call to the composite function.
+   * \return The JSON node for the layer.
+   */
+  std::shared_ptr<JSONGraphNode> CreateMrvlConcatLayer(const CallNode* cn) {
+    CompositeConcatNode nodes = UnpackCompositeConcat(cn);
+    ICHECK(nodes.concat != nullptr)
+        << "Marvell-Compiler-ERROR-Internal::attribute concat can't be nullptr";
+
+    const std::string mrvlLayerName = "Concat";
+    const std::string name = "concat";
+
+    std::vector<JSONGraphNodeEntry> inputs;
+    for (size_t i = 0; i < cn->args.size(); ++i) {
+      inputs.push_back(VisitExpr(cn->args[i])[0]);
+    }
+
+    std::vector<int64_t> layout_vec;
+    GetInputTensorShapeViaArg0(cn, &layout_vec);
+    std::string data_layout;
+    if (layout_vec.size() == 4) {
+      data_layout = "NHWC";
+    } else if (layout_vec.size() == 2) {
+      data_layout = "NC";
+    }
+    const std::string out_layout = data_layout;
+
+    auto json_node = std::make_shared<JSONGraphNode>(name, "kernel", inputs, 1);
+    SetCallNodeAttribute(json_node, nodes.concat);
+    SetMrvlLayerCommonAttrs(json_node, cn, layer_name_, mrvlLayerName, data_layout, "", out_layout);
+    return json_node;
+  }
+
+  /*!
+   * \brief Create a JSON representation of a composite LayoutTransform Reshape.
+   *
+   * data_layout is taken from the layout_transform's src_layout; out_layout is always "NC".
+   *
+   * \param cn The call to be represented.
+   * \return A JSON representation of a specific operator.
+   */
+  std::shared_ptr<JSONGraphNode> CreateMrvlLayoutTransposeReshapeLayer(const CallNode* cn) {
+    CompositeLayoutTransformReshapeNode nodes = UnpackCompositeLayoutTransposeReshape(cn);
+    ICHECK(nodes.transform != nullptr)
+        << "Marvell-Compiler-ERROR-Internal::attribute transform can't be nullptr";
+
+    const std::string name = "transformreshape";
+    const std::string out_layout = "NC";
+    std::vector<JSONGraphNodeEntry> inputs;
+    inputs.push_back(VisitExpr(cn->args[0])[0]);
+
+    const auto* layout_transform_attr = nodes.transform->attrs.as<LayoutTransformAttrs>();
+    // Validate the downcast before reading src_layout.
+    ICHECK(layout_transform_attr)
+        << "Marvell-Compiler-ERROR-Internal::Downcast to LayoutTransformAttrs failed.";
+    const std::string data_layout = layout_transform_attr->src_layout;
+
+    auto json_node = std::make_shared<JSONGraphNode>(name, "kernel", inputs, 1);
+    SetMrvlLayerCommonAttrs(json_node, cn, layer_name_, name, data_layout,
+                            "" /* no kernel_layout */, out_layout);
+    return json_node;
+  }
+
+  /*!
+   * \brief Build the JSON node for a composite fully-connected (dense) layer.
+   *
+   * JSON inputs are data, weight and, when an add is fused, bias. Layouts are fixed:
+   * data NC, kernel OI, output NC, bias -O.
+   *
+   * \param cn Call to the composite function.
+   * \return The JSON node for the layer.
+   */
+  std::shared_ptr<JSONGraphNode> CreateCompositeMrvlFcLayer(const CallNode* cn) {
+    CompositeFcNode nodes = UnpackCompositeFc(cn);
+
+    const std::string name = "nn.fc_ni2no";
+    const std::string mrvlLayerName = "FC";
+    const std::string data_layout = "NC";
+    const std::string kernel_layout = "OI";
+    const std::string out_layout = "NC";
+    const std::string bias_layout = "-O";
+
+    std::vector<JSONGraphNodeEntry> inputs;
+    inputs.push_back(VisitExpr(cn->args[0])[0]);
+    inputs.push_back(VisitExpr(nodes.fc->args[1])[0]);
+    if (nodes.add != nullptr) {
+      inputs.push_back(VisitExpr(nodes.add->args[1])[0]);
+    }
+
+    // Each call yields the next unique "<layer>_const_<N>" name (kernel first, then bias).
+    auto next_const_name = [this]() {
+      return layer_name_ + "_const_" + std::to_string(const_suffix_++);
+    };
+
+    auto json_node = std::make_shared<JSONGraphNode>(name, "kernel", inputs, 1);
+    JsonNodeSetAttr(json_node, "kernel_const_name", {next_const_name()});
+    SetCallNodeAttribute(json_node, nodes.fc);
+    if (nodes.add != nullptr) {
+      SetCallNodeAttribute(json_node, nodes.add);
+      JsonNodeSetAttr(json_node, "bias_const_name", {next_const_name()});
+      JsonNodeSetAttr(json_node, "bias_layout", {bias_layout});
+    }
+    if (nodes.activation != nullptr) {
+      JsonNodeSetAttr(json_node, "activation_type", {"relu"});
+    }
+
+    SetMrvlLayerCommonAttrs(json_node, cn, layer_name_, mrvlLayerName, data_layout, kernel_layout,
+                            out_layout);
+    return json_node;
+  }
+
+  /*!
+   * \brief Create a JSON representation of a composite maxpooling operator.
+   *
+   * The input must be in NHWC layout. The two pool_size entries are recorded as the node's
+   * "kernel_layout_shape" (layout "HW"), and padding attributes are added when the composite
+   * contains a pad.
+   *
+   * \param cn The call to be represented.
+   * \return A JSON representation of a specific operator.
+   */
+  std::shared_ptr<JSONGraphNode> CreateCompositeMrvlMaxpool2DLayer(const CallNode* cn) {
+    std::string mrvlLayerName = "Maxpool2D";
+    CompositePoolNode nodes = UnpackCompositePool(cn, mrvlLayerName);
+    const auto* maxpool_attr = nodes.pool->attrs.as<MaxPool2DAttrs>();
+    // Validate the downcast before reading any attribute field.
+    ICHECK(maxpool_attr) << "Marvell-Compiler-ERROR-Internal::Downcast to MaxPool2DAttrs failed.";
+    ICHECK(maxpool_attr->layout == "NHWC")
+        << "Marvell-Compiler-ERROR-Internal::"
+        << "Layout must be NHWC, has the module been pre-processed correctly?";
+    std::string name = "nn.maxpool2d_nhwc2nhwc";
+    std::string data_layout = maxpool_attr->layout;
+    std::string out_layout = maxpool_attr->layout;
+
+    std::vector<JSONGraphNodeEntry> inputs;
+    inputs.push_back(VisitExpr(cn->args[0])[0]);
+    auto json_node = std::make_shared<JSONGraphNode>(name, "kernel", inputs, 1);
+    SetCallNodeAttribute(json_node, nodes.pool);
+
+    // The pooling window (pool_size) becomes the kernel shape; both extents must be constants.
+    std::vector<int64_t> kernel_layout_vec;
+    for (int i = 0; i < 2; ++i) {
+      const int64_t* extent = tir::as_const_int(maxpool_attr->pool_size[i]);
+      ICHECK(extent) << "Marvell-Compiler-ERROR-Internal::Maxpool2D pool_size must be constant.";
+      kernel_layout_vec.push_back(*extent);
+    }
+    JsonNodeSetVecAttr(json_node, "kernel_layout_shape", kernel_layout_vec);
+    if (nodes.pad) SetMrvlLayerPadAttrs(json_node, nodes.pad);
+    SetMrvlLayerCommonAttrs(json_node, cn, layer_name_, mrvlLayerName, data_layout, "HW",
+                            out_layout);
+    return json_node;
+  }
+
+  /*!
+   * \brief Create a JSON representation of a composite avgpooling operator.
+   *
+   * The input must be in NHWC layout. The two pool_size entries are recorded as the node's
+   * "kernel_layout_shape" (layout "HW"), and padding attributes are added when the composite
+   * contains a pad.
+   *
+   * \param cn The call to be represented.
+   * \return A JSON representation of a specific operator.
+   */
+  std::shared_ptr<JSONGraphNode> CreateCompositeMrvlAvgpool2DLayer(const CallNode* cn) {
+    std::string mrvlLayerName = "Avgpool2D";
+    CompositePoolNode nodes = UnpackCompositePool(cn, mrvlLayerName);
+    const auto* avgpool_attr = nodes.pool->attrs.as<AvgPool2DAttrs>();
+    // Validate the downcast before reading any attribute field.
+    ICHECK(avgpool_attr) << "Marvell-Compiler-ERROR-Internal::Downcast to AvgPool2DAttrs failed.";
+    ICHECK(avgpool_attr->layout == "NHWC")
+        << "Marvell-Compiler-ERROR-Internal::"
+        << "Layout must be NHWC, has the module been pre-processed correctly?";
+    std::string name = "nn.avgpool2d_nhwc2nhwc";
+    std::string data_layout = avgpool_attr->layout;
+    std::string out_layout = avgpool_attr->layout;
+
+    std::vector<JSONGraphNodeEntry> inputs;
+    inputs.push_back(VisitExpr(cn->args[0])[0]);
+    auto json_node = std::make_shared<JSONGraphNode>(name, "kernel", inputs, 1);
+    SetCallNodeAttribute(json_node, nodes.pool);
+
+    // The pooling window (pool_size) becomes the kernel shape; both extents must be constants.
+    std::vector<int64_t> kernel_layout_vec;
+    for (int i = 0; i < 2; ++i) {
+      const int64_t* extent = tir::as_const_int(avgpool_attr->pool_size[i]);
+      ICHECK(extent) << "Marvell-Compiler-ERROR-Internal::Avgpool2D pool_size must be constant.";
+      kernel_layout_vec.push_back(*extent);
+    }
+    JsonNodeSetVecAttr(json_node, "kernel_layout_shape", kernel_layout_vec);
+    if (nodes.pad) SetMrvlLayerPadAttrs(json_node, nodes.pad);
+    SetMrvlLayerCommonAttrs(json_node, cn, layer_name_, mrvlLayerName, data_layout, "HW",
+                            out_layout);
+    return json_node;
+  }
+
+  /*!
+   * \brief Create a JSON representation of a composite globalavgpooling operator.
+   *
+   * The input must be in NHWC layout. The pooling window is taken from the input tensor shape
+   * returned by GetInputTensorShapeViaArg0: its H (index 1) and W (index 2) extents are recorded
+   * as the node's "kernel_layout_shape" (layout "HW").
+   *
+   * \param cn The call to be represented.
+   * \return A JSON representation of a specific operator.
+   */
+  std::shared_ptr<JSONGraphNode> CreateCompositeMrvlGlobalAvgpool2DLayer(const CallNode* cn) {
+    std::string mrvlLayerName = "GlobalAvgpool2D";
+    CompositePoolNode nodes = UnpackCompositePool(cn, mrvlLayerName);
+    const auto* globalavgpool_attr = nodes.pool->attrs.as<GlobalPool2DAttrs>();
+    // Validate the downcast before reading any attribute field.
+    ICHECK(globalavgpool_attr)
+        << "Marvell-Compiler-ERROR-Internal::Downcast to GlobalPool2DAttrs failed.";
+    ICHECK(globalavgpool_attr->layout == "NHWC")
+        << "Marvell-Compiler-ERROR-Internal::"
+        << "Layout must be NHWC, has the module been pre-processed correctly?";
+    std::string name = "nn.globalavgpool2d_nhwc2nhwc";
+    std::string data_layout = globalavgpool_attr->layout;
+    std::string out_layout = globalavgpool_attr->layout;
+
+    std::vector<JSONGraphNodeEntry> inputs;
+    inputs.push_back(VisitExpr(cn->args[0])[0]);
+    std::vector<int64_t> data_layout_vec;
+    GetInputTensorShapeViaArg0(cn, &data_layout_vec);
+    ICHECK(data_layout_vec.size() == 4)
+        << "Marvell-Compiler-ERROR-Internal::GlobalAvgpool2D expects a rank-4 (NHWC) input.";
+    // Window spans the whole feature map: H and W of the NHWC input.
+    std::vector<int64_t> kernel_layout_vec = {data_layout_vec[1], data_layout_vec[2]};
+
+    auto json_node = std::make_shared<JSONGraphNode>(name, "kernel", inputs, 1);
+    SetCallNodeAttribute(json_node, nodes.pool);
+    JsonNodeSetVecAttr(json_node, "kernel_layout_shape", kernel_layout_vec);
+    if (nodes.pad) SetMrvlLayerPadAttrs(json_node, nodes.pad);
+
+    SetMrvlLayerCommonAttrs(json_node, cn, layer_name_, mrvlLayerName, data_layout, "HW",
+                            out_layout);
+    return json_node;
+  }
+
+  /*!
+   * \brief Create a JSON representation of a standalone (non-composite) OpNode layer.
+   *
+   * Only "transpose" and "layout_transform" are handled; any other operator aborts with
+   * LOG(FATAL). For "layout_transform" the source/destination layouts become the node's data
+   * and output layouts; for "transpose" both layouts are left empty.
+   *
+   * \param cn The call to be represented.
+   * \return A JSON representation of a specific operator.
+   */
+  std::shared_ptr<JSONGraphNode> CreateMrvlLayer4OpNode(const CallNode* cn) {
+    const auto* op_node = cn->op.as<OpNode>();
+    ICHECK(op_node) << "Marvell-Compiler-ERROR-Internal::Downcast to OpNode failed.";
+    String op_name = op_node->name;
+
+    std::string name = op_name;
+    std::string mrvlLayerName = op_name;
+    std::string data_layout = "";
+    std::string out_layout = "";
+    std::vector<JSONGraphNodeEntry> inputs;
+    inputs.push_back(VisitExpr(cn->args[0])[0]);
+    // op_type_ is "kernel"
+    auto json_node = std::make_shared<JSONGraphNode>(name, "kernel", inputs, 1);
+    if (op_name == "transpose") {
+      SetCallNodeAttribute(json_node, cn);
+    } else if (op_name == "layout_transform") {
+      SetCallNodeAttribute(json_node, cn);
+      const auto* layout_transform_attr = cn->attrs.as<LayoutTransformAttrs>();
+      // Guard the downcast before reading src_layout / dst_layout.
+      ICHECK(layout_transform_attr)
+          << "Marvell-Compiler-ERROR-Internal::Downcast to LayoutTransformAttrs failed.";
+      data_layout = layout_transform_attr->src_layout;
+      out_layout = layout_transform_attr->dst_layout;
+    } else {
+      LOG(FATAL) << "Can't handle this OpNode: " << AsText(GetRef<Call>(cn), false);
+    }
+    SetMrvlLayerCommonAttrs(json_node, cn, layer_name_, mrvlLayerName, data_layout,
+                            "" /* no kernel_layout */, out_layout);
+    return json_node;
+  }
+};
+
+/*!
+ * \brief Break a string into the pieces separated by a delimiter character.
+ *
+ * Follows std::getline semantics: an empty input yields no pieces, a trailing delimiter does
+ * not produce a trailing empty piece, and leading or repeated delimiters produce empty pieces.
+ *
+ * \param s The string to split.
+ * \param delim The separator character.
+ * \return The pieces of s, in order of appearance.
+ */
+std::vector<std::string> split(const std::string& s, char delim) {
+  std::vector<std::string> pieces;
+  std::istringstream reader(s);
+  for (std::string piece; std::getline(reader, piece, delim);) {
+    pieces.push_back(piece);
+  }
+  return pieces;
+}
+
+/*!
+ * \brief Generate JSON meta files and then return a runtime module for Mrvl.
+ *
+ * \note This consists of a series of IR functions, each of which represents
+ * a full Mrvl subgraph/region (in tvmc mode) or one fused Mrvl backend layer
+ * macro function (in dbg mode), so that they can be computed on the Mrvl accelerator.
+ *
+ * The function is serialized with MrvlJSONSerializer, post-processed by the Python-side
+ * "tvm.mrvl.GetNodesJSONString" and "tvm.mrvl.ModifyConstNames" globals, and the first
+ * '|'-separated piece of the result is handed to "runtime.mrvl_runtime_create".
+ *
+ * \param ref The ext_func Relay expression/module to be executed using extern ops.
+ * \return A runtime module.
+ */
+runtime::Module MrvlCompiler(const ObjectRef& ref) {
+  ICHECK(ref->IsInstance<FunctionNode>())
+      << "Marvell-Compiler-ERROR-Internal::Downcast to FunctionNode failed.";
+
+  Function func = Downcast<Function>(ref);
+  std::string func_name = backend::GetExtSymbol(func);
+
+  // The frontend must attach its compiler options to the function. They are not consumed
+  // here yet, but a missing attribute indicates a malformed input.
+  ICHECK(func->GetAttr<String>("compiler_opts_string").defined())
+      << "Marvell-Compiler-ERROR-Internal::Missing compiler_opts_string attribute on "
+      << func_name;
+
+  MrvlJSONSerializer serializer(func_name, func);
+  serializer.serialize();
+  std::string graph_json = serializer.GetJSON();
+
+  // Collect Nodes.json and Const.json
+  const auto* get_json = runtime::Registry::Get("tvm.mrvl.GetNodesJSONString");
+  ICHECK(get_json != nullptr)
+      << "Marvell-Compiler-ERROR-Internal::Cannot find tvm.mrvl.GetNodesJSONString";
+  std::string nodes_json_string = (*get_json)(graph_json);
+  auto consts_json_string = serializer.GetConstJSONString();
+
+  // Rename constants to a form acceptable by backend
+  const auto* modify_consts = runtime::Registry::Get("tvm.mrvl.ModifyConstNames");
+  ICHECK(modify_consts != nullptr)
+      << "Marvell-Compiler-ERROR-Internal::Cannot find tvm.mrvl.ModifyConstNames";
+  std::string modified_json = (*modify_consts)(nodes_json_string, consts_json_string);
+  auto json_vec = split(modified_json, '|');
+  ICHECK(!json_vec.empty())
+      << "Marvell-Compiler-ERROR-Internal::tvm.mrvl.ModifyConstNames returned no nodes JSON";
+
+  const auto* pf = runtime::Registry::Get("runtime.mrvl_runtime_create");
+  ICHECK(pf != nullptr) << "Cannot find software simulator runtime module to create";
+  runtime::Module runtime_lib = (*pf)(func_name, json_vec[0]);
+  return runtime_lib;
+}
+
+TVM_REGISTER_GLOBAL("relay.ext.mrvl").set_body_typed(MrvlCompiler);
+
+} // namespace mrvl
+} // namespace contrib
+
+} // namespace relay
+} // namespace tvm
diff --git a/src/relay/backend/contrib/mrvl/compiler_attr.cc b/src/relay/backend/contrib/mrvl/compiler_attr.cc
new file mode 100644
index 0000000000..4309212e33
--- /dev/null
+++ b/src/relay/backend/contrib/mrvl/compiler_attr.cc
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/contrib/mrvl/compiler_attr.cc
+ * \brief Marvell MLIP specific attributes
+ */
+
+#include <stdlib.h>
+#include <tvm/ir/transform.h>
+#include <tvm/target/target.h>
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace mrvl {
+
+/*!
+ * \brief Attributes to store the compiler options for Mrvl MLIP.
+ *
+ * NOTE(review): working_dir is declared as a member but has no TVM_ATTR_FIELD entry in
+ * TVM_DECLARE_ATTRS below, so it is not part of the reflected attribute set; confirm whether
+ * it should be registered.
+ */
+struct MrvlCompilerConfigNode : public tvm::AttrsNode<MrvlCompilerConfigNode> {
+  String mcpu;
+  IntImm num_tiles;
+  String mattr;
+  String working_dir;
+
+  TVM_DECLARE_ATTRS(MrvlCompilerConfigNode, "ext.attrs.MrvlCompilerConfigNode") {
+    // The two literals are concatenated; the trailing space keeps the sentence readable.
+    TVM_ATTR_FIELD(mcpu)
+        .describe(
+            "The CPU class of Marvell(R) ML Inference Processor; "
+            "possible values = {cn10ka, cnf10kb}")
+        .set_default("cn10ka");
+    TVM_ATTR_FIELD(num_tiles)
+        .describe("Maximum number of tiles that may be used, possible values = {1,2,4,8}")
+        .set_default(IntImm(DataType::Int(64), 8));
+    TVM_ATTR_FIELD(mattr)
+        .describe("Attributes for MLIP; possible values = {quantize,wb_pin_ocm}")
+        .set_default("");
+  }
+};
+
+/*! \brief Non-nullable object reference wrapping MrvlCompilerConfigNode. */
+class MrvlCompilerConfig : public Attrs {
+ public:
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(MrvlCompilerConfig, Attrs, MrvlCompilerConfigNode);
+};
+
+// Register the node type and expose it as the "relay.ext.mrvl.options" pass-config option.
+TVM_REGISTER_NODE_TYPE(MrvlCompilerConfigNode);
+TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.mrvl.options", MrvlCompilerConfig);
+
+// Register "mrvl" as a target kind with device type kDLCPU.
+TVM_REGISTER_TARGET_KIND("mrvl", kDLCPU);
+
+} // namespace mrvl
+} // namespace contrib
+} // namespace relay
+} // namespace tvm
diff --git a/src/runtime/contrib/mrvl/mrvl_runtime.cc b/src/runtime/contrib/mrvl/mrvl_runtime.cc
new file mode 100644
index 0000000000..89e8ff108e
--- /dev/null
+++ b/src/runtime/contrib/mrvl/mrvl_runtime.cc
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/mrvl/mrvl_runtime.cc
+ * \brief runtime implementation for Marvell Software Simulator.
+ */
+
+#include <assert.h>
+#include <ctype.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/registry.h>
+
+#include <cstddef>
+#include <fstream>
+#include <string>
+#include <vector>
+
+#include "../json/json_node.h"
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+/*!
+ * \brief Runtime module for one Marvell (Mrvl) subgraph.
+ *
+ * Holds the subgraph symbol name and its serialized JSON. It is intended to compile that JSON
+ * to a binary for Marvell hardware and run it with the Marvell software simulator (MlModel);
+ * running is not implemented yet, so invoking the subgraph function aborts.
+ */
+class MarvellSimulatorModuleNode : public ModuleNode {
+ public:
+  /*!
+   * \brief Construct the module.
+   * \param symbol_name The name of the subgraph / relay function.
+   * \param nodes_json The serialized JSON representation of the relay function.
+   */
+  MarvellSimulatorModuleNode(const std::string& symbol_name, const std::string& nodes_json)
+      : symbol_name_(symbol_name), nodes_json_(nodes_json) {}
+
+  /*! \brief Type key; matches the "runtime.module.loadbinary_mrvl_sim" loader registration. */
+  const char* type_key() const override { return "mrvl_sim"; }
+
+  /*! \brief Get the property of the runtime module. */
+  int GetPropertyMask() const final {
+    return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable;
+  }
+
+  /*!
+   * \brief Get a packed function.
+   *
+   * "get_symbol" returns the subgraph symbol, "get_const_vars" returns an empty array, and the
+   * subgraph symbol itself runs the subgraph. Any other name yields a null PackedFunc.
+   *
+   * \param name The name/symbol of the function.
+   * \param sptr_to_self The pointer to the module node; captured by the returned closures.
+   * \return The packed function.
+   */
+  PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) override {
+    if (name == "get_symbol") {
+      return PackedFunc(
+          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->symbol_name_; });
+    } else if (name == "get_const_vars") {
+      return PackedFunc(
+          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = Array<String>{}; });
+    } else if (this->symbol_name_ == name) {
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        Run(args);
+        *rv = 0;
+      });
+    }
+    return PackedFunc(nullptr);
+  }
+
+  /*! \brief Serialize the symbol name followed by the nodes JSON. */
+  void SaveToBinary(dmlc::Stream* stream) override {
+    stream->Write(symbol_name_);
+    stream->Write(nodes_json_);
+  }
+
+  /*!
+   * \brief Rebuild a module from the data written by SaveToBinary.
+   * \param strm A dmlc::Stream* passed as void*.
+   * \return The reconstructed module.
+   */
+  static Module LoadFromBinary(void* strm) {
+    dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
+    std::string symbol_name;
+    std::string nodes_json;
+    // Load the symbol_name and other data to construct the module
+    ICHECK(stream->Read(&symbol_name))
+        << "Marvell-Compiler-ERROR-Internal::Loading symbol name failed";
+    ICHECK(stream->Read(&nodes_json))
+        << "Marvell-Compiler-ERROR-Internal::Loading nodes json failed";
+    auto n = make_object<MarvellSimulatorModuleNode>(symbol_name, nodes_json);
+    return Module(n);
+  }
+
+  /*!
+   * \brief Get the source generated by codegen.
+   *
+   * \param format The format to return (not inspected; the nodes JSON is always returned).
+   * \return A string of JSON.
+   */
+  String GetSource(const String& format = "json") override { return nodes_json_; }
+
+ protected:
+  /*! \brief The subgraph / relay function symbol. */
+  std::string symbol_name_;
+  /*! \brief The serialized JSON of the subgraph. */
+  std::string nodes_json_;
+
+  /*! \brief Execute the subgraph; not supported yet, so this always aborts. */
+  void Run(TVMArgs args) {
+    LOG(FATAL) << "Marvell-Compiler-ERROR-Internal::Run not supported for Marvell Runtime yet!";
+  }
+};
+
+/*!
+ * \brief Factory for the Marvell simulator runtime module.
+ *
+ * \param symbol_name The name of the subgraph / relay function.
+ * \param nodes_json The serialized JSON representation of the relay function.
+ * \return A runtime::Module wrapping a new MarvellSimulatorModuleNode.
+ */
+runtime::Module MarvellSimulatorModuleRuntimeCreate(const String& symbol_name,
+                                                    const String& nodes_json) {
+  return runtime::Module(make_object<MarvellSimulatorModuleNode>(symbol_name, nodes_json));
+}
+
+// Factory looked up by name from MrvlCompiler to wrap the generated nodes JSON in a module.
+TVM_REGISTER_GLOBAL("runtime.mrvl_runtime_create")
+ .set_body_typed(MarvellSimulatorModuleRuntimeCreate);
+// Binary loader paired with the "mrvl_sim" type_key; reads back the symbol name and nodes
+// JSON in the order SaveToBinary wrote them.
+TVM_REGISTER_GLOBAL("runtime.module.loadbinary_mrvl_sim")
+ .set_body_typed(MarvellSimulatorModuleNode::LoadFromBinary);
+} // namespace contrib
+} // namespace runtime
+} // namespace tvm
diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc
index b9b3fe3059..cc84f7a675 100644
--- a/src/support/libinfo.cc
+++ b/src/support/libinfo.cc
@@ -167,6 +167,10 @@
#define TVM_INFO_USE_MKL "NOT-FOUND"
#endif
+#ifndef TVM_INFO_USE_MRVL
+#define TVM_INFO_USE_MRVL "NOT-FOUND"
+#endif
+
#ifndef TVM_INFO_USE_AMX
#define TVM_INFO_USE_AMX "NOT-FOUND"
#endif
@@ -327,6 +331,7 @@ TVM_DLL Map<String, String> GetLibInfo() {
{"USE_MICRO", TVM_INFO_USE_MICRO},
{"USE_MIOPEN", TVM_INFO_USE_MIOPEN},
{"USE_MKL", TVM_INFO_USE_MKL},
+ {"USE_MRVL", TVM_INFO_USE_MRVL},
{"USE_MSVC_MT", TVM_INFO_USE_MSVC_MT},
{"USE_NNPACK", TVM_INFO_USE_NNPACK},
{"USE_OPENCL", TVM_INFO_USE_OPENCL},
diff --git a/python/tvm/relay/op/contrib/__init__.py b/tests/python/contrib/test_mrvl/__init__.py
similarity index 70%
copy from python/tvm/relay/op/contrib/__init__.py
copy to tests/python/contrib/test_mrvl/__init__.py
index 01708e8452..736bad9370 100644
--- a/python/tvm/relay/op/contrib/__init__.py
+++ b/tests/python/contrib/test_mrvl/__init__.py
@@ -14,16 +14,4 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=wildcard-import
-"""Contrib modules."""
-from .register import get_pattern_table, register_pattern_table
-
-from .arm_compute_lib import *
-from .dnnl import *
-from .bnns import *
-from .coreml import *
-from .ethosn import *
-from .libtorch import *
-from .tensorrt import *
-from .cutlass import *
-from .clml import *
+"""Infrastructure and tests for Marvell"""
diff --git a/tests/python/contrib/test_mrvl/infrastructure.py b/tests/python/contrib/test_mrvl/infrastructure.py
new file mode 100644
index 0000000000..c46753d4e7
--- /dev/null
+++ b/tests/python/contrib/test_mrvl/infrastructure.py
@@ -0,0 +1,105 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
+
+"""Infrastructure to Test Marvell Code Generation"""
+import json
+import os
+
+import tvm
+from tvm import relay
+from tvm.relay.op.contrib import mrvl
+
+
+def get_cpu_op_count(mod):
+ """Traverse graph counting ops offloaded to TVM."""
+
+ class Counter(tvm.relay.ExprVisitor):
+ def __init__(self):
+ super().__init__()
+ self.count = 0
+
+ def visit_call(self, call):
+ if isinstance(call.op, tvm.ir.Op):
+ self.count += 1
+
+ super().visit_call(call)
+
+ c = Counter()
+ c.visit(mod["main"])
+ return c.count
+
+
+def build_module(
+ mod,
+ target,
+ params=None,
+ enable_mrvl=True,
+ tvm_ops=0,
+ mrvl_partitions=1,
+):
+ """Partition and build module for mrvl codegen."""
+ if isinstance(mod, tvm.relay.expr.Call):
+ mod = tvm.IRModule.from_expr(mod)
+ if params is None:
+ params = {}
+
+ with tvm.transform.PassContext(opt_level=3):
+ if enable_mrvl:
+ mod = mrvl.partition_for_mrvl(mod, params)
+ tvm_op_count = get_cpu_op_count(mod)
+ assert tvm_op_count == tvm_ops, "Got {} TVM operators, expected {}".format(
+ tvm_op_count, tvm_ops
+ )
+ partition_count = 0
+ for global_var in mod.get_global_vars():
+ if "mrvl" in global_var.name_hint:
+ partition_count += 1
+
+ assert mrvl_partitions == partition_count, "Got {} mrvl partitions, expected {}".format(
+ partition_count, mrvl_partitions
+ )
+ return relay.build(mod, target, params=params)
+
+
+def extract_mrvl_modules(module):
+ """Get a list of all built mrvl runtime modules."""
+ return list(filter(lambda mod: mod.type_key == "mrvl_sim", module.get_lib().imported_modules))
+
+
+def verify_codegen(
+ module, num_mrvl_modules=1, params=None, target="llvm", tvm_ops=0, contains=None
+):
+ """Check mrvl codegen against a known good output."""
+ module = build_module(
+ module,
+ target,
+ params=params,
+ tvm_ops=tvm_ops,
+ mrvl_partitions=num_mrvl_modules,
+ )
+
+ mrvl_modules = extract_mrvl_modules(module)
+ assert len(mrvl_modules) == num_mrvl_modules, (
+ f"The number of mrvl modules produced ({len(mrvl_modules)}) does not "
+ f"match the expected value ({num_mrvl_modules})."
+ )
+
+ # Check if expected string is found inside actual string
+ if contains is not None:
+ actual_str = json.dumps(json.loads(mrvl_modules[0].get_source()))
+ assert actual_str.find(contains)
diff --git a/tests/python/contrib/test_mrvl/test_mrvl.py b/tests/python/contrib/test_mrvl/test_mrvl.py
new file mode 100644
index 0000000000..03fdcedc93
--- /dev/null
+++ b/tests/python/contrib/test_mrvl/test_mrvl.py
@@ -0,0 +1,174 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
+
+"""Test Marvell BYOC partitioning, code generation and runtime"""
+
+import numpy as np
+
+import tvm
+from tvm import relay
+import tvm.relay.testing
+from tvm.testing.utils import requires_mrvl
+from tvm.relay.op.contrib.mrvl import partition_for_mrvl
+from .infrastructure import verify_codegen
+from tvm.testing import requires_mrvl
+
+
+@requires_mrvl
+def test_mrvl_fuse():
+ """Check the number of Mrvl partitions partition_for_mrvl produces for small conv nets.
+
+ Covers stacked conv2d blocks with optional bias_add / batch_norm / sigmoid, a
+ conv2d + bias_add + batch_norm + add + relu graph, and the relay.testing MobileNet workload.
+ """
+ def get_blocks(
+ prefix,
+ data,
+ in_channel,
+ out_channel,
+ include_bias_add=True,
+ include_bn=True,
+ include_sigmoid=False,
+ ):
+ # conv2d(3x3, padding 1) -> [bias_add] -> [batch_norm] -> [sigmoid] -> relu
+ # NOTE(review): in_channel is not used by this helper.
+ weight = relay.var(prefix + "weight")
+ bias = relay.var(prefix + "bias")
+ bn_gamma = relay.var(prefix + "bn_gamma")
+ bn_beta = relay.var(prefix + "bn_beta")
+ bn_mmean = relay.var(prefix + "bn_mean")
+ bn_mvar = relay.var(prefix + "bn_var")
+
+ layer = relay.nn.conv2d(
+ data=data, weight=weight, kernel_size=(3, 3), channels=out_channel, padding=(1, 1)
+ )
+ if include_bias_add:
+ layer = relay.nn.bias_add(layer, bias)
+ if include_bn:
+ bn_output = relay.nn.batch_norm(layer, bn_gamma, bn_beta, bn_mmean, bn_mvar)
+ layer = bn_output[0]
+ if include_sigmoid:
+ layer = relay.sigmoid(layer)
+ layer = relay.nn.relu(layer)
+ return layer
+
+ def get_net(include_bias_add=True, include_bn=True, include_sigmoid=False):
+ # Two stacked blocks on a 1x3x224x224 input; block2 never has bias_add or batch_norm.
+ data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
+ block1 = get_blocks("block1_", data, 3, 8, include_bias_add, include_bn, include_sigmoid)
+ block2 = get_blocks("block2_", block1, 8, 8, False, False, include_sigmoid)
+ return relay.Function(relay.analysis.free_vars(block2), block2)
+
+ def test_detect_pattern(include_bias_add, include_bn, include_sigmoid, num_expected_partition):
+ # Counts every module function other than "main" as a partition (assumes no other helpers).
+ net = get_net(include_bias_add, include_bn, include_sigmoid)
+ mod, params = tvm.relay.testing.create_workload(net)
+ mod = partition_for_mrvl(mod, params)
+ assert len(mod.functions) - 1 == num_expected_partition
+
+ def test_sum_pattern(num_expected_partition):
+ def get_conv2d_bn_sum_relu(
+ x_shape=(1, 32, 8, 8),
+ k_shape=(16, 32, 3, 3),
+ sum_shape=(1, 16, 6, 6),
+ dtype="float32",
+ ):
+ x = relay.var("x", shape=(x_shape), dtype=dtype)
+ kernel = relay.const(np.random.randint(0, 1, k_shape).astype(dtype))
+ bias = relay.var("bias", shape=(k_shape[0],), dtype=dtype)
+ beta = relay.const(np.zeros(k_shape[0]).astype(dtype))
+ gamma = relay.const(np.ones(k_shape[0]).astype(dtype))
+ moving_mean = relay.const(np.zeros(k_shape[0]).astype(dtype))
+ moving_var = relay.const(np.ones(k_shape[0]).astype(dtype))
+ sum_data = relay.var("data1", shape=sum_shape, dtype=dtype)
+
+ dic = {"x": x_shape, "bias": (k_shape[0],), "sum_data": sum_shape}
+ # NOTE(review): the relay var above is named "data1", so the "sum_data" params entry
+ # matches no function parameter name; presumably only "bias" gets bound — confirm intent.
+ param_lst = ["bias", "sum_data"]
+
+ conv = relay.nn.conv2d(
+ x,
+ kernel,
+ channels=k_shape[0],
+ kernel_size=k_shape[2:4],
+ )
+ conv_bias = relay.nn.bias_add(conv, bias)
+ conv_bias_bn, _, _ = relay.nn.batch_norm(
+ conv_bias,
+ gamma=gamma,
+ beta=beta,
+ moving_mean=moving_mean,
+ moving_var=moving_var,
+ axis=1,
+ center=True,
+ scale=True,
+ epsilon=1e-5,
+ )
+ conv_bias_bn_sum = relay.add(conv_bias_bn, sum_data)
+ return relay.nn.relu(conv_bias_bn_sum), dic, param_lst
+
+ net, dic, param_lst = get_conv2d_bn_sum_relu()
+ net = tvm.IRModule.from_expr(net)
+ params = {x: np.random.uniform(-1, 1, dic[x]).astype("float32") for x in param_lst}
+ mod = partition_for_mrvl(net, params)
+ assert len(mod.functions) - 1 == num_expected_partition
+
+ def test_partition():
+ # Every include_sigmoid=True case expects 2 partitions; presumably sigmoid is not
+ # offloaded to Mrvl and splits the graph — TODO confirm.
+ test_detect_pattern(True, False, False, 1)
+ test_detect_pattern(False, True, False, 1)
+ test_detect_pattern(False, False, True, 2)
+ test_detect_pattern(True, True, False, 1)
+ test_detect_pattern(True, False, True, 2)
+ test_detect_pattern(False, True, True, 2)
+ test_detect_pattern(False, False, False, 1)
+ test_detect_pattern(True, True, True, 2)
+ test_sum_pattern(1)
+
+ def test_partition_mobilenet(num_expected_partition):
+ mod, params = relay.testing.mobilenet.get_workload()
+ mod = partition_for_mrvl(mod, params)
+ assert len(mod.functions) - 1 == num_expected_partition
+
+ test_partition()
+ test_partition_mobilenet(1)
+
+
+@requires_mrvl
+def test_conv2d():
+ """Test conv2d operator for "mrvl" targets"""
+
+ x = relay.var("x", shape=(1, 3, 224, 224))
+ w = relay.const(np.zeros((16, 3, 3, 3), dtype="float32"))
+ y = relay.nn.conv2d(x, w, strides=[2, 2], padding=[1, 1, 1, 1], kernel_size=[3, 3])
+ func = relay.Function([x], y)
+ params = {}
+ params["w"] = np.random.rand(16, 3, 3, 3).astype("float32")
+ mod = tvm.IRModule()
+ mod["main"] = func
+ verify_codegen(mod, params=params, tvm_ops=1, contains="mrvl.conv2d_nhwc2nhwc")
+
+
+@requires_mrvl
+def test_dense():
+ """Test dense operator for "mrvl" targets"""
+
+ x = relay.var("x", shape=(1, 16))
+ w = relay.const(np.zeros((32, 16), dtype="float32"))
+ y = relay.nn.dense(x, w)
+ func = relay.Function([x], y)
+ params = {}
+ params["w"] = np.random.rand(16, 3, 3, 3).astype("float32")
+ mod = tvm.IRModule()
+ mod["main"] = func
+ verify_codegen(mod, params=params, tvm_ops=0, contains="mrvl.fc_ni2no")
+
+
+if __name__ == "__main__":
+ test_mrvl_fuse()
+ test_conv2d()
+ test_dense()
diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py
index ca56172cb7..4bf6b27ccf 100644
--- a/tests/python/driver/tvmc/test_compiler.py
+++ b/tests/python/driver/tvmc/test_compiler.py
@@ -870,6 +870,24 @@ def test_compile_tflite_module_with_external_codegen_vitis_ai(tflite_mobilenet_v
assert os.path.exists(dumps_path)
+@tvm.testing.requires_mrvl
+def test_compile_pytorch_module_with_external_codegen_mrvl(pytorch_resnet18):
+ tvmc_model = tvmc.load(pytorch_resnet18, shape_dict={"input": [1, 3, 224, 224]})
+ tvmc_package = tvmc.compiler.compile_model(
+ tvmc_model,
+ target="mrvl, llvm",
+ dump_code="relay",
+ )
+ dumps_path = tvmc_package.package_path + ".relay"
+
+ # check for output types
+ assert type(tvmc_package) is TVMCPackage
+ assert type(tvmc_package.graph) is str
+ assert type(tvmc_package.lib_path) is str
+ assert type(tvmc_package.params) is bytearray
+ assert os.path.exists(dumps_path)
+
+
def test_compile_tflite_module_with_external_codegen_ethosu(
tmpdir_factory, tflite_mobilenet_v1_1_quant
):
diff --git a/tests/python/driver/tvmc/test_composite_target.py b/tests/python/driver/tvmc/test_composite_target.py
index ca08d3e66f..2335563b3e 100644
--- a/tests/python/driver/tvmc/test_composite_target.py
+++ b/tests/python/driver/tvmc/test_composite_target.py
@@ -35,6 +35,7 @@ def test_get_codegen_names():
assert "ethos-n" in names
assert "vitis-ai" in names
+ assert "mrvl" in names
assert len(names) > 0
diff --git a/tests/python/driver/tvmc/test_target_options.py b/tests/python/driver/tvmc/test_target_options.py
index d3df83f346..194047e7a6 100644
--- a/tests/python/driver/tvmc/test_target_options.py
+++ b/tests/python/driver/tvmc/test_target_options.py
@@ -53,6 +53,25 @@ def test_target_to_argparse_known_codegen():
assert parsed.target_cmsis_nn_mcpu == "cortex-m3"
+@tvm.testing.requires_mrvl
+def test_target_to_argparse_for_mrvl_hybrid():
+ parser = argparse.ArgumentParser()
+ generate_target_args(parser)
+ parsed, _ = parser.parse_known_args(
+ [
+ "--target=mrvl, llvm",
+ "--target-mrvl-mattr=wb_pin_ocm=1,quantize=fp16",
+ "--target-mrvl-num_tiles=2",
+ "--target-mrvl-mcpu=cnf10kb",
+ ]
+ )
+
+ assert parsed.target == "mrvl, llvm"
+ assert parsed.target_mrvl_mattr == "wb_pin_ocm=1,quantize=fp16"
+ assert parsed.target_mrvl_num_tiles == 2
+ assert parsed.target_mrvl_mcpu == "cnf10kb"
+
+
def test_mapping_target_args():
parser = argparse.ArgumentParser()
generate_target_args(parser)
diff --git a/python/tvm/relay/op/contrib/__init__.py b/tests/scripts/task_config_build_mrvl.sh
old mode 100644
new mode 100755
similarity index 60%
copy from python/tvm/relay/op/contrib/__init__.py
copy to tests/scripts/task_config_build_mrvl.sh
index 01708e8452..cb5adeab38
--- a/python/tvm/relay/op/contrib/__init__.py
+++ b/tests/scripts/task_config_build_mrvl.sh
@@ -1,3 +1,4 @@
+#!/usr/bin/env bash
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -14,16 +15,19 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=wildcard-import
-"""Contrib modules."""
-from .register import get_pattern_table, register_pattern_table
-from .arm_compute_lib import *
-from .dnnl import *
-from .bnns import *
-from .coreml import *
-from .ethosn import *
-from .libtorch import *
-from .tensorrt import *
-from .cutlass import *
-from .clml import *
+set -euxo pipefail
+
+BUILD_DIR=$1
+mkdir -p "$BUILD_DIR"
+cd "$BUILD_DIR"
+cp ../cmake/config.cmake .
+
+echo set\(USE_LLVM \"/usr/bin/llvm-config-15 --link-static\"\) >> config.cmake
+echo set\(CMAKE_CXX_FLAGS \"-Werror -Wno-error=range-loop-construct\"\) >> config.cmake
+echo set\(USE_LIBBACKTRACE COMPILE\) >> config.cmake
+echo set\(BACKTRACE_ON_SEGFAULT ON\) >> config.cmake
+
+# Enable Mrvl target
+echo set\(USE_MRVL ON\) >> config.cmake
+echo set\(HIDE_PRIVATE_SYMBOLS ON\) >> config.cmake