Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/04/01 16:25:46 UTC

[GitHub] [tvm] mbaret opened a new pull request #10871: [CUDNN] Add cuDNN as a Relay partitioning target (BYOC)

mbaret opened a new pull request #10871:
URL: https://github.com/apache/tvm/pull/10871


   This adds infrastructure to support offloading of Relay patterns to cuDNN. In this initial commit, only softmax is supported. Later PRs will include support for more operators, including some limited fused patterns.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] mbs-octoml commented on a change in pull request #10871: [CUDNN] Add cuDNN as a Relay partitioning target (BYOC)

Posted by GitBox <gi...@apache.org>.
mbs-octoml commented on a change in pull request #10871:
URL: https://github.com/apache/tvm/pull/10871#discussion_r840770963



##########
File path: python/tvm/relay/op/contrib/cudnn.py
##########
@@ -0,0 +1,129 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""cuDNN Relay integration."""
+from typing import Callable, List, Tuple, Dict, Optional
+
+import tvm
+import tvm.ir
+from tvm import relay
+from tvm import te
+from tvm.relay import transform
+from tvm.contrib import cudnn
+
+from ...dataflow_pattern import is_op, wildcard
+from .register import register_pattern_table
+
+
+def partition_for_cudnn(
+    mod: tvm.IRModule, params: Optional[Dict[str, tvm.runtime.NDArray]] = None
+) -> tvm.IRModule:
+    """Partition the graph to offload for cuDNN.
+
+    Parameters
+    ----------
+    mod : tvm.IRModule
+        The module to partition.
+    params : Optional[Dict[str, tvm.runtime.NDArray]]
+        Constant input parameters.
+
+    Returns
+    -------
+    tvm.IRModule
+        The partitioned module.
+    """
+
+    seq = tvm.transform.Sequential(
+        [
+            transform.InferType(),
+            transform.MergeComposite(pattern_table()),
+            transform.AnnotateTarget("cudnn"),
+            transform.PartitionGraph(),
+            transform.InferType(),
+        ]
+    )
+    return seq(mod)
+
+
+@register_pattern_table("cudnn")
+def pattern_table() -> List[Tuple[str, relay.Pattern, Callable[[relay.Call], bool]]]:
+    """Get the cuDNN pattern table."""
+
+    def softmax_pattern() -> relay.Pattern:
+        """Create pattern for softmax."""
+        return is_op("nn.softmax")(wildcard())
+
+    def check_softmax(matched: relay.Call) -> bool:
+        """Check if softmax is supported by cuDNN."""
+        if matched.args[0].checked_type.dtype not in ["float64", "float32", "float16"]:
+            return False
+
+        return True
+
+    return [
+        ("cudnn.softmax", softmax_pattern(), check_softmax),
+    ]
+
+
+_LowerFunc = Callable[[relay.Call, List[te.Tensor]], te.Tensor]
+_LOWER_MAP: Dict[str, _LowerFunc] = {}
+
+
+def _lower_composite(comp_name: str) -> Callable[[_LowerFunc], _LowerFunc]:
+    """Register a lowering function for a given composite function name."""
+
+    def _register(f: _LowerFunc) -> _LowerFunc:
+        _LOWER_MAP[comp_name] = f
+        return f
+
+    return _register
+
+
+@tvm._ffi.register_func("relay.ext.cudnn")

Review comment:
       Would now be a good time to hoist this boilerplate into a library_byoc.py helper, or something similar?
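Below is a minimal sketch of what such a shared helper might contain. The module name `library_byoc.py` comes from the comment above; the class `CompositeLoweringRegistry` and its methods are hypothetical and illustrative, not an existing TVM API. The sketch folds the `_LOWER_MAP` dictionary and `_lower_composite` decorator from the diff into one reusable object, so each BYOC library module would only declare its registrations. Rejecting duplicate registrations is a design choice of this sketch, not something the diff does.

```python
from typing import Any, Callable, Dict, List, TypeVar

F = TypeVar("F", bound=Callable[..., Any])


class CompositeLoweringRegistry:
    """Hypothetical shared helper: maps composite function names
    (e.g. "cudnn.softmax") to the functions that lower them."""

    def __init__(self, library: str) -> None:
        self.library = library
        self._lower_map: Dict[str, Callable[..., Any]] = {}

    def register(self, comp_name: str) -> Callable[[F], F]:
        """Decorator playing the role of the per-library `_lower_composite`."""

        def _register(f: F) -> F:
            if comp_name in self._lower_map:
                raise ValueError(f"duplicate lowering for {comp_name!r}")
            self._lower_map[comp_name] = f
            return f

        return _register

    def lower(self, comp_name: str, *args: Any, **kwargs: Any) -> Any:
        """Dispatch to the lowering function registered for `comp_name`."""
        try:
            func = self._lower_map[comp_name]
        except KeyError:
            raise NotImplementedError(
                f"{self.library}: no lowering registered for {comp_name!r}"
            ) from None
        return func(*args, **kwargs)

    def supported(self) -> List[str]:
        """Names of all composites this library can lower, sorted."""
        return sorted(self._lower_map)


# A per-library module (e.g. cudnn.py) would then shrink to registrations only.
cudnn_lowering = CompositeLoweringRegistry("cudnn")


@cudnn_lowering.register("cudnn.softmax")
def _lower_softmax(call, inputs):
    # Placeholder body: a real lowering would build and return a te.Tensor.
    return ("softmax", call, inputs)
```

The library's `relay.ext.<name>` entry point could then look up lowerings through the registry instead of a module-level dictionary; that part is left out of the sketch.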

##########
File path: python/tvm/relay/op/contrib/cudnn.py
##########
@@ -0,0 +1,129 @@
+@tvm._ffi.register_func("relay.ext.cudnn")
+def relay_to_runtime(partition: relay.Function) -> tvm.runtime.Module:
+    """Compile cuDNN Relay functions to a runtime module."""
+    assert isinstance(partition, relay.Function)
+    assert isinstance(partition.body, relay.Call)
+    assert isinstance(partition.body.op, relay.Function)
+
+    global_name = str(partition.attrs.global_symbol)
+    target = tvm.target.cuda()

Review comment:
       Just noticed this: I think Target.current() would be better here, so that the CUDA target parameters the user supplied are not lost.
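The reasoning can be illustrated without TVM by a minimal, hypothetical sketch: `SketchTarget`, `target_scope`, `default_cuda` and `resolve_target` are stand-ins invented here, not TVM APIs. Constructing a fresh default target (as `tvm.target.cuda()` does in the diff) discards whatever options were attached to the target in scope, while reading the ambient target keeps them. In TVM itself the ambient target is returned by `tvm.target.Target.current(allow_none=True)`, which yields `None` when no target scope is active. Falling back to a default in that case, rather than raising, is a choice of this sketch.

```python
import contextvars
from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass(frozen=True)
class SketchTarget:
    """Hypothetical stand-in for a compilation target (NOT tvm.target.Target)."""

    kind: str
    attrs: Dict[str, str] = field(default_factory=dict)


# Holds the target of the innermost active `target_scope`, or None.
_CURRENT_TARGET = contextvars.ContextVar("current_target", default=None)


class target_scope:
    """Make `target` the ambient target for the duration of a `with` block."""

    def __init__(self, target: SketchTarget) -> None:
        self._target = target
        self._token: Optional[contextvars.Token] = None

    def __enter__(self) -> SketchTarget:
        self._token = _CURRENT_TARGET.set(self._target)
        return self._target

    def __exit__(self, *exc_info) -> None:
        # Restore whatever target (if any) was active before this scope.
        _CURRENT_TARGET.reset(self._token)


def default_cuda() -> SketchTarget:
    """Build a fresh default target; any user-supplied options are lost."""
    return SketchTarget("cuda")


def resolve_target() -> SketchTarget:
    """Prefer the ambient target so its options survive; else fall back."""
    current = _CURRENT_TARGET.get()
    return current if current is not None else default_cuda()


user_target = SketchTarget("cuda", {"arch": "sm_80"})
with target_scope(user_target):
    print(resolve_target().attrs)  # {'arch': 'sm_80'}
    print(default_cuda().attrs)  # {}
print(resolve_target().attrs)  # {}
```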


[GitHub] [tvm] mbaret commented on pull request #10871: [CUDNN] Add cuDNN as a Relay partitioning target (BYOC)

mbaret commented on pull request #10871:
URL: https://github.com/apache/tvm/pull/10871#issuecomment-1086112542


   cc @mikepapadim @mbs-octoml 
