You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/07/20 15:41:23 UTC

[GitHub] [tvm] lhutton1 commented on a diff in pull request #12087: [UMA] UMA v1.0

lhutton1 commented on code in PR #12087:
URL: https://github.com/apache/tvm/pull/12087#discussion_r924943667


##########
python/tvm/relay/backend/contrib/uma/api/partitioner.py:
##########
@@ -0,0 +1,118 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Partitioner base class of the Universal Modular Accelerator Interface (UMA)"""
+
+from typing import Callable, Dict, List, Tuple, Optional
+
+import tvm
+from tvm import relay
+from tvm.relay.build_module import bind_params_by_name
+from tvm.relay.op.contrib.register import register_pattern_table
+from .utils import PassPhase
+
+
+PatternTable = List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]
+
+
+class UMAPartitioner:
+    """Partitioner base class of the Universal Modular Accelerator Interface (UMA)."""
+
+    def __init__(self, target_name: str, merge_compiler_regions: bool = True) -> None:
+        self.target_name = target_name
+        self.merge_compiler_regions = merge_compiler_regions
+
+        self._relay_passes: List[Tuple[PassPhase, tvm.transform.Pass]] = []
+        self._patterns: PatternTable = []
+
+    def add_pattern(
+        self,
+        name: str,
+        pattern: tvm.relay.dataflow_pattern.DFPattern,
+        predicate: Optional[Callable] = None,
+    ) -> None:
+        """Add pattern to UMA partitioner
+
+        Parameters
+        ----------
+        name : str
+            relay name of pattern
+
+        pattern: tvm.relay.dataflow_pattern.DFPattern
+            pattern description as DFPattern
+
+        predicate: Optional[Callable]
+            Optional predicate
+
+        """
+
+        name = self.target_name + "." + name
+        if predicate:
+            self._patterns.append((name, pattern, predicate))
+        else:
+            self._patterns.append((name, pattern))
+
+    def _pattern_table(self) -> PatternTable:
+        return self._patterns
+
+    def register(self) -> None:
+        """Register all relevant relay-to-relay functions."""
+        register_pattern_table(self.target_name, self._pattern_table)
+
+    def partition(
+        self, mod: tvm.IRModule, params: Optional[Dict[str, tvm.runtime.NDArray]] = None
+    ) -> tvm.IRModule:
+        """Partition the relay graph in by the NPU supported and unsupported parts.
+
+        Parameters
+        ----------
+        mod : tvm.IRModule
+            The relay module to be partitioned.
+
+        params: Optional[Dict[str, tvm.runtime.NDArray]]
+
+        Returns
+        -------
+        out : tvm.IRModule
+            The partitioned relay module.
+
+        """
+        if params:
+            mod["main"] = bind_params_by_name(mod["main"], params)
+
+        mod = relay.transform.InferType()(mod)

Review Comment:
   I wonder if using a single `Sequential` block might make things a bit clearer here, e.g.:
   ```
   passes = []
   passes.extend([p[1] for p in self._relay_passes if p[0] == PassPhase.PRE_PARTITIONING])
   passes.append(relay.transform.MergeComposite(self._pattern_table()))
   ...
   seq = tvm.transform.Sequential(passes)
   mod = seq(mod)
   ```
   Further, it would reduce the need to explicitly run `InferType` and to manage other pass dependencies.



##########
python/tvm/relay/backend/contrib/uma/api/lower.py:
##########
@@ -0,0 +1,159 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Lowering base class of the Universal Modular Accelerator Interface (UMA)"""
+
+from typing import List, Tuple, Callable, Optional
+
+import tvm
+from tvm import relay, te
+from tvm.relay.op.op import register_strategy
+from . import _ffi_api
+from .utils import PassPhase
+
+
+class UMALower:
+    """Lowering base class of the Universal Modular Accelerator Interface (UMA)."""
+
+    def __init__(self, target_name: str) -> None:
+        self.target_name = target_name
+
+        self._operator_strategies: List[
+            Tuple[
+                str,
+                Callable[
+                    [tvm.ir.Attrs, tvm.ir.Array, tvm.ir.TensorType, tvm.target.Target],
+                    tvm.relay.op.op.OpStrategy,
+                ],
+                Optional[int],
+            ]
+        ] = []

Review Comment:
   Would it make sense to break these type annotations into smaller named type aliases, like `PatternTable` in partitioner.py?



##########
python/tvm/relay/backend/contrib/uma/backend.py:
##########
@@ -0,0 +1,299 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Backend base class of the Universal Modular Accelerator Interface (UMA)"""
+
+from abc import ABC, abstractmethod
+from typing import Union, Dict, Callable, Optional, Any
+
+import tvm
+from tvm.relay.backend.contrib.uma.api.codegen import UMACodegen
+from tvm.relay.backend.contrib.uma.api.lower import UMALower
+from tvm.relay.backend.contrib.uma.api.partitioner import UMAPartitioner
+from tvm.relay.backend.contrib.uma.api.utils import PassPhase
+
+
+class UMABackend(ABC):
+    """Backend base class of the Universal Modular Accelerator Interface (UMA)"""
+
+    def __init__(self, merge_compiler_regions: bool = True) -> None:
+        self._target_attrs: Dict = {}
+        self._target_preprocessor: Callable[[str], Dict[str, Any]] = None
+        self._relay_to_relay = UMAPartitioner(self.target_name, merge_compiler_regions)
+        self._relay_to_tir = UMALower(self.target_name)
+        self._tir_to_runtime = UMACodegen(self.target_name)
+
+    @property
+    @abstractmethod
+    def target_name(self) -> str:
+        """Name of the hardware target.
+
+        Returns
+        -------
+        out : str
+            The hardware target name.
+        """
+        ...
+
+    ############################################################################
+    # Target configuration
+    ############################################################################
+    def _register_target_attr(
+        self,
+        name: str,
+        default: Optional[Union[str, int, bool]] = "",
+    ) -> None:
+        """Register a target attribute name that can be used during target instantiation.
+        Parameters
+        ----------
+        name: str
+           The name of the target attribute.
+
+        default: Optional[Union[str, int, bool]]
+            A default value for the attribute.
+            If none is provided, the attribute will be treated as a string.
+
+        Example
+        -------
+        Here is an example of how two attribute options are registered.
+
+        .. code-block:: python
+
+            self._register_target_attr("attrA", default=0)
+            self._register_target_attr("attrB", default=False)
+        """
+        self._target_attrs[name] = default
+
+    ############################################################################
+    # Relay to Relay function registration
+    ############################################################################
+    def _register_relay_pass(self, phase: PassPhase, relay_pass: tvm.transform.Pass) -> None:
+        """Registers a relay pass at the given phase in the lowering process.
+
+        Parameters
+        ----------
+        phase: PassPhase
+           The phase at which the pass is registered.
+
+        relay_pass: tvm.transform.Pass
+            The relay pass to be registered.
+
+        Example
+        -------
+        Here is an example of how two relay passes are registered.
+        Passes of the same phase are executed in the order they are registered.
+
+        .. code-block:: python
+
+            self._register_relay_pass(PassPhase.PRE_PARTITIONING, MyPassA)
+            self._register_relay_pass(PassPhase.POST_PARTITIONING, MyPassB)
+
+        Where a relay pass can look like this:
+
+        .. code-block:: python
+
+            @tvm.ir.transform.module_pass(opt_level=0)
+            class MyPassA:
+                def transform_module(self, mod, ctx):
+                    # My pass functionality...
+                    return mod
+        """
+        self._relay_to_relay._relay_passes.append((phase, relay_pass))
+
+    def _register_pattern(
+        self,
+        name: str,
+        pattern: tvm.relay.dataflow_pattern.DFPattern,
+        predicate: Optional[Callable] = None,
+    ) -> None:
+        """Registers a dataflow pattern that is used to partition the relay graph.
+
+        Parameters
+        ----------
+        name: str
+           The name of the pattern.
+
+        pattern: tvm.relay.dataflow_pattern.DFPattern
+            The dataflow pattern.
+
+        predicate: Callable Receiving the matched pattern and

Review Comment:
   nit: it looks like part of this sentence is missing from the `predicate` description?



##########
tests/scripts/task_python_uma.sh:
##########
@@ -0,0 +1,24 @@
+#!/usr/bin/env bash
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set -euxo pipefail
+
+source tests/scripts/setup-pytest-env.sh
+
+run_pytest ctypes test_uma tests/python/contrib/test_uma
+run_pytest cython3 test_uma  tests/python/contrib/test_uma

Review Comment:
   Is there a need to run these tests separately? I believe they already get run as part of `task_python_integration` (https://github.com/apache/tvm/blob/main/tests/scripts/task_python_integration.sh#L64).



##########
tests/python/contrib/test_uma/test_uma_vanilla_accelerator.py:
##########
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""UMA testcase for the vanilla_accelerator accelerator"""
+import pytest
+
+import tvm
+from tvm import tir
+from tvm.relay.dataflow_pattern import is_op, wildcard
+from tvm.relay.backend.contrib.uma.api.utils import PassPhase
+from tvm.relay.backend.contrib.uma.backend import UMABackend
+from tvm.relay.backend.contrib.uma._template.passes import (
+    MyAiHwConv2dPass as VanillaAcceleratorConv2dPass,
+)
+from tvm.relay.backend.contrib.uma._template.codegen import gen_includes
+
+from tvm.relay.backend.contrib.uma._template.patterns import conv2d_pattern
+
+# def conv2d_pattern():
+#     pattern = is_op("nn.conv2d")(wildcard(), wildcard())
+#     pattern = pattern.has_attr({"strides": [1, 1]})
+#     return pattern

Review Comment:
   nit: remove this commented-out code.



##########
tests/python/contrib/test_uma/test_partition.py:
##########
@@ -0,0 +1,71 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+

Review Comment:
   Should the tests be guarded in some way based on whether or not TVM was compiled with UMA support?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org