Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/03/10 17:53:34 UTC

[GitHub] [tvm] areusch commented on a change in pull request #7612: [WIP][Docs] Getting Started With TVM Tutorial

areusch commented on a change in pull request #7612:
URL: https://github.com/apache/tvm/pull/7612#discussion_r591740619



##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -223,22 +295,21 @@
 cc.create_shared(temp.relpath("myadd.so"), [temp.relpath("myadd.o")])
 print(temp.listdir())
 
-######################################################################
+################################################################################
 # .. note:: Module Storage Format
 #
-#   The CPU (host) module is directly saved as a shared library (.so).
-#   There can be multiple customized formats of the device code.
-#   In our example, the device code is stored in ptx, as well as a meta
-#   data json file. They can be loaded and linked separately via import.
-#
+#   The CPU (host) module is directly saved as a shared library (.so). There

Review comment:
       I guess I'd also generally support splitting the tutorials at each point where we could save a meaningful artifact to disk, e.g. after Relay import and after tvm.relay.build.
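
To make the two save points concrete, here is a rough sketch, not a definitive implementation. The function names and file names are hypothetical, and it assumes `tvm.ir.save_json`, `relay.save_param_dict`, and the `export_library` method on the object returned by `tvm.relay.build` behave as they do in the TVM release current at the time of this review.

```python
import os


def save_after_import(mod, params, out_dir):
    """Save the Relay module and parameters produced by a frontend importer."""
    # Imported lazily so this sketch can be loaded without TVM installed.
    import tvm
    from tvm import relay

    mod_path = os.path.join(out_dir, "model.json")  # hypothetical file name
    params_path = os.path.join(out_dir, "model.params")  # hypothetical file name
    with open(mod_path, "w") as f:
        # Assumption: save_json returns the serialized module as a string.
        f.write(tvm.ir.save_json(mod))
    with open(params_path, "wb") as f:
        # Assumption: save_param_dict returns the parameters as bytes.
        f.write(relay.save_param_dict(params))
    return mod_path, params_path


def save_after_build(lib, out_dir):
    """Save the compiled library returned by tvm.relay.build."""
    lib_path = os.path.join(out_dir, "model.so")  # hypothetical file name
    lib.export_library(lib_path)
    return lib_path
```

Each tutorial could then begin by loading the artifact the previous one saved, rather than repeating the import or build step.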

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -163,52 +145,155 @@
 fadd(a, b, c)
 tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 
-######################################################################
-# Inspect the Generated Code
-# --------------------------
-# You can inspect the generated code in TVM. The result of tvm.build
-# is a TVM Module. fadd is the host module that contains the host wrapper,
-# it also contains a device module for the CUDA (GPU) function.
-#
-# The following code fetches the device module and prints the content code.
-#
-if tgt == "cuda" or tgt == "rocm" or tgt.startswith("opencl"):
-    dev_module = fadd.imported_modules[0]
-    print("-----GPU code-----")
-    print(dev_module.get_source())
-else:
-    print(fadd.get_source())
+################################################################################
+# Updating the Schedule to Use Paralleism
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# Now that we've illustrated the fundamentals of TE, let's go deeper into what
+# schedules do, and how they can be used to optimize tensor expressions for
+# different architectures. A schedule is a series of steps that are applied to
+# an expression to transform it in a number of different ways. When a schedule
+# is applied to an expression in TE, the inputs and outputs remain the same,
+# but when compiled the implementation of the expression can change. This
+# tensor addition, in the default schedule, is run serially but is easy to
+# parallelize across all of the processor threads. We can apply the parallel
+# schedule operation to our computation.
 
-######################################################################
-# .. note:: Code Specialization
-#
-#   As you may have noticed, the declarations of A, B and C all
-#   take the same shape argument, n. TVM will take advantage of this
-#   to pass only a single shape argument to the kernel, as you will find in
-#   the printed device code. This is one form of specialization.
+s[C].parallel(C.op.axis[0])
+
+################################################################################
+# The ``tvm.lower`` command will generate the Intermediate Representation (IR)
+# of the TE, with the corresponding schedule. By lowering the expression as we
+# apply different schedule operations, we can see the effect of scheduling on
+# the ordering of the computation.
+
+print(tvm.lower(s, [A, B, C], simple_mode=True))
+
+################################################################################
+# It's now possible for TVM to run these blocks on independent threads. Let's
+# compile and run this new schedule with the parallel operation applied:
+
+fadd_parallel = tvm.build(s, [A, B, C], tgt, target_host=tgt_host, name="myadd_parallel")
+fadd_parallel(a, b, c)
+
+tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
+
+################################################################################
+# Updating the Schedule to Use Vectorization
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# Modern CPUs also have the ability to perform SIMD operations on floating
+# point values, and we can apply another schedule to our computation expression
+# to take advantage of this. Accomplishing this requires multiple steps: first
+# we have to split the schedule into inner and outer loops using the split
+# scheduling primitive. The inner loops can use vectorization to use SIMD
+# instructions using the vectorize scheduling primitive, then the outer loops
+# can be parallelized using the parallel scheduling primitive. Choose the split
+# factor to be the number of threads on your CPU.
+
+# Recreate the schedule, since we modified it with the parallel operation in the previous example
+n = te.var("n")
+A = te.placeholder((n,), name="A")
+B = te.placeholder((n,), name="B")
+C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
+
+s = te.create_schedule(C.op)
+
+factor = 4
+
+outer, inner = s[C].split(C.op.axis[0], factor=factor)
+s[C].parallel(outer)
+s[C].vectorize(inner)
+
+print(tvm.lower(s, [A, B, C], simple_mode=True))
+
+################################################################################
+# We've defined, scheduled, and compiled a vector addition operator, which we
+# were then able to execute on the TVM runtime. We can save the operator as a
+# library, which we can then load later using the TVM runtime.
+
+################################################################################
+# Targeting Vector Addition for GPUs (Optional)
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# TVM is capable of targeting multiple architectures. In the next example, we
+# will target compilation of the vector addition to GPUs
+
+# If you want to run this code, change ``run_cuda = True``
+run_cuda = False
+if run_cuda:
+
+# Change this target to the correct backend for you gpu. For example: cuda (NVIDIA GPUs), rocm (Radeon GPUS), OpenGL (???).
+    tgt_gpu = "cuda"
+
+# Recreate the schedule

Review comment:
       nit: indent?
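
A sketch of what the indented version might look like. The `tgt_gpu = None` default is an assumption added so the snippet runs on its own; it is not part of the patch, and the comment wording is only illustrative.

```python
# If you want to run this code, change ``run_cuda = True``
run_cuda = False
tgt_gpu = None  # assumption: default so the name exists when the branch is skipped

if run_cuda:
    # Change this target to the correct backend for your GPU,
    # for example "cuda" (NVIDIA GPUs) or "rocm" (Radeon GPUs).
    tgt_gpu = "cuda"

    # Recreate the schedule. The rest of the GPU example would follow here,
    # indented to the same level as the code it describes.
```

Python ignores the indentation of comment-only lines, so the current version still runs, but aligning the comments with the guarded block makes it clear they belong to the `if run_cuda:` branch.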

##########
File path: tutorials/get_started/introduction.py
##########
@@ -0,0 +1,132 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Introduction
+============
+**Authors**:
+`Jocelyn Shiue <https://github.com/>`_,
+`Chris Hoge <https://github.com/hogepodge>`_
+
+Apache TVM is an open source machine learning compiler framework for CPUs,
+GPUs, and machine learning accelerators. It aims to enable machine learning
+engineers to optimize and run computations efficiently on any hardware backend.
+The purpose of this tutorial is to take a guided tour through all of the major
+features of TVM by defining and demonstrating key concepts. A new user should
+be able to work through the tutorial from start to finish and be able to
+operate TVM for automatic model optimization, while having a basic
+understanding of the TVM architecture and how it works.
+
+Contents
+--------
+
+#. :doc:`Introduction <introduction>`
+#. :doc:`Installing TVM <install>`
+#. :doc:`Compiling and Optimizing a Model with TVMC <tvmc_command_line_driver>`
+#. :doc:`Compiling and Optimizing a Model with the Python AutoScheduler <auto_tuning_with_python>`
+#. :doc:`Working with Operators Using Tensor Expressions <tensor_expr_get_started>`
+#. :doc:`Optimizing Operators with Templates and AutoTVM <autotvm_matmul>`
+#. :doc:`Optimizing Operators with AutoScheduling <tune_matmul_x86>`
+#. :doc:`Cross Compilation and Remote Procedure Calls (RPC) <cross_compilation_and_rpc>`
+#. :doc:`Compiling Deep Learning Models for GPUs <relay_quick_start>`
+"""
+
+################################################################################
+# An Overview of TVM and Model Optimization
+# =========================================
+#
+# The diagram below illustrates the steps a machine model takes as it is
+# transformed with the TVM optimizing compiler framework.
+#
+# .. image:: /_static/img/tvm.png
+#   :width: 100%
+#   :alt: A High Level View of TVM
+#
+# 1. Import the model from a framework like *Tensorflow*, *Pytorch*, or *Onnx*.
+#    The importer layer is where TVM can ingest models from other frameworks, like
+#    ONNX, Tensorflow, or PyTorch. The level of support that TVM offers for each
+#    frontend varies as we are constantly improving the open source project. If
+#    you're having issues importing your model into TVM, you may want to try
+#    converting it to ONNX.
+#
+# 2. Translate to *Relay,* TVM's high level model language.

Review comment:
       nit: `*Relay*,` ?

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -163,52 +145,155 @@
 fadd(a, b, c)
 tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 
-######################################################################
-# Inspect the Generated Code
-# --------------------------
-# You can inspect the generated code in TVM. The result of tvm.build
-# is a TVM Module. fadd is the host module that contains the host wrapper,
-# it also contains a device module for the CUDA (GPU) function.
-#
-# The following code fetches the device module and prints the content code.
-#
-if tgt == "cuda" or tgt == "rocm" or tgt.startswith("opencl"):
-    dev_module = fadd.imported_modules[0]
-    print("-----GPU code-----")
-    print(dev_module.get_source())
-else:
-    print(fadd.get_source())
+################################################################################
+# Updating the Schedule to Use Paralleism
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# Now that we've illustrated the fundamentals of TE, let's go deeper into what
+# schedules do, and how they can be used to optimize tensor expressions for
+# different architectures. A schedule is a series of steps that are applied to
+# an expression to transform it in a number of different ways. When a schedule
+# is applied to an expression in TE, the inputs and outputs remain the same,
+# but when compiled the implementation of the expression can change. This
+# tensor addition, in the default schedule, is run serially but is easy to
+# parallelize across all of the processor threads. We can apply the parallel
+# schedule operation to our computation.
 
-######################################################################
-# .. note:: Code Specialization
-#
-#   As you may have noticed, the declarations of A, B and C all
-#   take the same shape argument, n. TVM will take advantage of this
-#   to pass only a single shape argument to the kernel, as you will find in
-#   the printed device code. This is one form of specialization.
+s[C].parallel(C.op.axis[0])
+
+################################################################################
+# The ``tvm.lower`` command will generate the Intermediate Representation (IR)
+# of the TE, with the corresponding schedule. By lowering the expression as we
+# apply different schedule operations, we can see the effect of scheduling on
+# the ordering of the computation.
+
+print(tvm.lower(s, [A, B, C], simple_mode=True))
+
+################################################################################
+# It's now possible for TVM to run these blocks on independent threads. Let's
+# compile and run this new schedule with the parallel operation applied:
+
+fadd_parallel = tvm.build(s, [A, B, C], tgt, target_host=tgt_host, name="myadd_parallel")
+fadd_parallel(a, b, c)
+
+tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
+
+################################################################################
+# Updating the Schedule to Use Vectorization
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# Modern CPUs also have the ability to perform SIMD operations on floating
+# point values, and we can apply another schedule to our computation expression
+# to take advantage of this. Accomplishing this requires multiple steps: first
+# we have to split the schedule into inner and outer loops using the split
+# scheduling primitive. The inner loops can use vectorization to use SIMD
+# instructions using the vectorize scheduling primitive, then the outer loops
+# can be parallelized using the parallel scheduling primitive. Choose the split
+# factor to be the number of threads on your CPU.
+
+# Recreate the schedule, since we modified it with the parallel operation in the previous example
+n = te.var("n")
+A = te.placeholder((n,), name="A")
+B = te.placeholder((n,), name="B")
+C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
+
+s = te.create_schedule(C.op)
+
+factor = 4
+
+outer, inner = s[C].split(C.op.axis[0], factor=factor)
+s[C].parallel(outer)
+s[C].vectorize(inner)
+
+print(tvm.lower(s, [A, B, C], simple_mode=True))
+
+################################################################################
+# We've defined, scheduled, and compiled a vector addition operator, which we
+# were then able to execute on the TVM runtime. We can save the operator as a
+# library, which we can then load later using the TVM runtime.
+
+################################################################################
+# Targeting Vector Addition for GPUs (Optional)
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# TVM is capable of targeting multiple architectures. In the next example, we
+# will target compilation of the vector addition to GPUs
+
+# If you want to run this code, change ``run_cuda = True``
+run_cuda = False
+if run_cuda:
+
+# Change this target to the correct backend for you gpu. For example: cuda (NVIDIA GPUs), rocm (Radeon GPUS), OpenGL (???).

Review comment:
       nit: fix the ???

##########
File path: tutorials/get_started/introduction.py
##########
@@ -0,0 +1,132 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Introduction
+============
+**Authors**:
+`Jocelyn Shiue <https://github.com/>`_,
+`Chris Hoge <https://github.com/hogepodge>`_
+
+Apache TVM is an open source machine learning compiler framework for CPUs,
+GPUs, and machine learning accelerators. It aims to enable machine learning
+engineers to optimize and run computations efficiently on any hardware backend.
+The purpose of this tutorial is to take a guided tour through all of the major
+features of TVM by defining and demonstrating key concepts. A new user should
+be able to work through the tutorial from start to finish and be able to
+operate TVM for automatic model optimization, while having a basic
+understanding of the TVM architecture and how it works.
+
+Contents
+--------
+
+#. :doc:`Introduction <introduction>`
+#. :doc:`Installing TVM <install>`
+#. :doc:`Compiling and Optimizing a Model with TVMC <tvmc_command_line_driver>`
+#. :doc:`Compiling and Optimizing a Model with the Python AutoScheduler <auto_tuning_with_python>`
+#. :doc:`Working with Operators Using Tensor Expressions <tensor_expr_get_started>`
+#. :doc:`Optimizing Operators with Templates and AutoTVM <autotvm_matmul>`
+#. :doc:`Optimizing Operators with AutoScheduling <tune_matmul_x86>`
+#. :doc:`Cross Compilation and Remote Procedure Calls (RPC) <cross_compilation_and_rpc>`
+#. :doc:`Compiling Deep Learning Models for GPUs <relay_quick_start>`
+"""
+
+################################################################################
+# An Overview of TVM and Model Optimization
+# =========================================
+#
+# The diagram below illustrates the steps a machine model takes as it is
+# transformed with the TVM optimizing compiler framework.
+#
+# .. image:: /_static/img/tvm.png
+#   :width: 100%
+#   :alt: A High Level View of TVM
+#
+# 1. Import the model from a framework like *Tensorflow*, *Pytorch*, or *Onnx*.
+#    The importer layer is where TVM can ingest models from other frameworks, like
+#    ONNX, Tensorflow, or PyTorch. The level of support that TVM offers for each
+#    frontend varies as we are constantly improving the open source project. If
+#    you're having issues importing your model into TVM, you may want to try
+#    converting it to ONNX.
+#
+# 2. Translate to *Relay,* TVM's high level model language.
+#    A model that has been imported into TVM is represented in Relay. Relay is a
+#    functional language and intermediate representation (IR) for neural networks.
+#    It has support for:
+#
+#    - Traditional data flow-style representations
+#    - Functional-style scoping, let-binding which makes it a fully featured
+#      differentiable language
+#    - Ability to allow the user to mix the two programming styles
+#
+#    Relay applies several high-level optimization to the model, after which
+#    is runs the Relay Fusion Pass. To aid in the process of converting to

Review comment:
       what's the Relay Fusion pass?



