Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/03/09 00:15:34 UTC

[GitHub] [tvm] tkonolige commented on a change in pull request #7612: [WIP][Docs] Getting Started With TVM Tutorial

tkonolige commented on a change in pull request #7612:
URL: https://github.com/apache/tvm/pull/7612#discussion_r589813112



##########
File path: tutorials/get_started/install.py
##########
@@ -0,0 +1,50 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Installing TVM
+==============
+**Authors**:
+`Jocelyn Shiue <https://github.com/>`_,
+`Chris Hoge <https://github.com/hogepodge>`_
+
+Depending on your needs and your working environment, there are a few different
+methods for installing TVM. These include:
+
+* Installing from the TLCPack conda and pip packages
+* Installing from source
+"""
+
+################################################################################
+# Installing from TLCPack
+# ------------------------
+# TVM is packaged and distributed as part of the volunteer TLCPack community.

Review comment:
       Can you show the pip install command for tlcpack here?
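This thread does not give the exact tlcpack pip command the reviewer asks for. Until it is filled in, the section could show one step that works for any install method: checking afterwards that the `tvm` package can be imported. The sketch below uses only the Python standard library plus `tvm.__version__`. It assumes nothing about which package or channel provided TVM.

```python
import importlib.util


def tvm_is_installed() -> bool:
    """Return True if a ``tvm`` package can be found on the current Python path."""
    return importlib.util.find_spec("tvm") is not None


if tvm_is_installed():
    import tvm

    # tvm.__version__ holds the version string of the installed package.
    print("TVM", tvm.__version__, "is installed")
else:
    print("TVM was not found; install it from a package or from source first")
```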

##########
File path: tutorials/get_started/introduction.py
##########
@@ -0,0 +1,137 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Introduction
+============
+**Authors**:
+`Jocelyn Shiue <https://github.com/>`_,
+`Chris Hoge <https://github.com/hogepodge>`_
+
+Apache TVM is an open source machine learning compiler framework for CPUs,
+GPUs, and machine learning accelerators. It aims to enable machine learning
+engineers to optimize and run computations efficiently on any hardware backend.
+The purpose of this tutorial is to take a guided tour through all of the major
+features of TVM by defining and demonstrating key concepts. A new user should
+be able to work through the tutorial from start to finish and be able to
+operate TVM for automatic model optimization, while having a basic
+understanding of the TVM architecture and how it works.
+
+Contents
+--------
+
+#. :doc:`Introduction <introduction>`
+#. :doc:`Installing TVM <install>`
+#. :doc:`Compiling and Optimizing a Model with TVMC <tvmc_command_line_driver>`
+#. :doc:`Compiling and Optimizing a Model with the Python AutoScheduler <auto_tuning_with_python>`
+#. :doc:`Working with Operators Using Tensor Expressions <tensor_expr_get_started>`
+#. :doc:`Optimizing Operators with Templates and AutoTVM <autotvm_matmul>`
+#. :doc:`Optimizing Operators with AutoScheduling <tune_matmul_x86>`
+#. :doc:`Cross Compilation and Remote Procedure Calls (RPC) <cross_compilation_and_rpc>`
+#. :doc:`Compiling Deep Learning Models for GPUs <relay_quick_start>`
+"""
+
+################################################################################
+# An Overview of TVM and Model Optimization
+# =========================================
+#
+# The diagram below illustrates the steps a machine learning model takes as it
+# is transformed with the TVM optimizing compiler framework.
+#
+# .. image:: /_static/img/tvm.png
+#   :width: 100%
+#   :alt: A High Level View of TVM
+#
+# 1. Import the model from a framework like *TensorFlow*, *PyTorch*, or *ONNX*.
+#    The importer layer is where TVM can ingest models from other frameworks, like
+#    ONNX, TensorFlow, or PyTorch. The level of support that TVM offers for each
+#    frontend varies as we are constantly improving the open source project. If
+#    you're having issues importing your model into TVM, you may want to try
+#    converting it to ONNX.
+#
+# 2. Translate to *Relay*, TVM's high-level model language.
+#    A model that has been imported into TVM is represented in Relay. Relay is a
+#    functional language and intermediate representation (IR) for neural networks.
+#    It has support for:
+#
+#    - Traditional data flow-style representations
+#    - Functional-style scoping and let-binding, which make it a fully featured
+#      differentiable language
+#    - The ability to mix the two programming styles
+#
+#    Relay (more specifically, its fusion pass) is in charge of splitting the
+#    neural network into small subgraphs, each of which is a task.

Review comment:
       You introduce `task` here. Maybe define it?
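Here is one possible way to define `task` for readers, sketched under simplifying assumptions. A fusion pass walks the operator graph and groups each compute-heavy operator with the cheap elementwise operators that follow it. Each resulting group is one small subgraph, a task, that is lowered and tuned as a unit.

The toy below makes several simplifications:

* It works on a linear list of operators rather than a graph.
* The pattern labels (`out_ewise_fusable`, `elemwise`, `opaque`) are made up.
* It applies a greedy grouping rule.

TVM's real fusion pass works on a dataflow graph with its own operator-pattern rules, so treat this only as an illustration.

```python
from typing import List, Tuple

# Each op is (name, pattern). The pattern labels are simplified, hypothetical
# stand-ins for the operator patterns a real fusion pass consults.
Op = Tuple[str, str]


def split_into_tasks(ops: List[Op]) -> List[List[str]]:
    """Greedily fuse elementwise ops into the preceding group (toy rule)."""
    groups: List[List[Op]] = []
    for name, pattern in ops:
        can_fuse = pattern == "elemwise" and bool(groups) and groups[-1][0][1] != "opaque"
        if can_fuse:
            groups[-1].append((name, pattern))  # joins the current task
        else:
            groups.append([(name, pattern)])  # starts a new task
    return [[name for name, _ in group] for group in groups]


model = [
    ("conv2d", "out_ewise_fusable"),
    ("add", "elemwise"),
    ("relu", "elemwise"),
    ("conv2d", "out_ewise_fusable"),
    ("relu", "elemwise"),
    ("softmax", "opaque"),
]
print(split_into_tasks(model))
# [['conv2d', 'add', 'relu'], ['conv2d', 'relu'], ['softmax']]
```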

##########
File path: tutorials/get_started/introduction.py
##########
@@ -0,0 +1,137 @@
+#
+# 3. Lower to *TE*, tensor expressions that define the *computational
+#    operations* of the neural network.
+#    Upon completing the import and high-level optimizations, the next step is
+#    to decide how to implement the Relay representation to a hardware target.

Review comment:
       I don't think implement is the correct word here. Lower is what you want, but you define it later on. Maybe 'convert'?
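This illustrates the word the reviewer is after: lowering turns a high-level task into a loop-level program. In the pure-Python toy below, a fused task such as `add` followed by `relu` becomes a single loop. That loop applies both operators per element, with no intermediate buffer. The names `lower_task`, `BINARY` and `UNARY` are invented for the illustration. TVM performs this step by building tensor expressions, not Python closures.

```python
from typing import Callable, Dict, List, Sequence

# Hypothetical operator tables for the toy; real TE operators are symbolic.
BINARY: Dict[str, Callable[[float, float], float]] = {
    "add": lambda x, y: x + y,
    "mul": lambda x, y: x * y,
}
UNARY: Dict[str, Callable[[float], float]] = {
    "relu": lambda x: max(x, 0.0),
    "neg": lambda x: -x,
}


def lower_task(task: Sequence[str]) -> Callable[[List[float], List[float]], List[float]]:
    """Lower a fused task (one binary op, then unary ops) into one loop."""
    head, *tail = task
    binary = BINARY[head]
    unaries = [UNARY[name] for name in tail]

    def kernel(a: List[float], b: List[float]) -> List[float]:
        out = [0.0] * len(a)
        for i in range(len(a)):  # one loop for the whole task
            value = binary(a[i], b[i])
            for op in unaries:  # fused ops reuse `value`; no temporary arrays
                value = op(value)
            out[i] = value
        return out

    return kernel


add_relu = lower_task(["add", "relu"])
print(add_relu([1.0, -3.0], [1.0, 1.0]))  # [2.0, 0.0]
```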

##########
File path: tutorials/get_started/introduction.py
##########
@@ -0,0 +1,137 @@
+#
+# 3. Lower to *TE*, tensor expressions that define the *computational
+#    operations* of the neural network.
+#    Upon completing the import and high-level optimizations, the next step is
+#    to decide how to implement the Relay representation to a hardware target.
+#    Relay (more specifically, its fusion pass) is in charge of splitting the
+#    neural network into small subgraphs, each of which is a task. Here,
+#    lowering means converting these tasks into TE. The first step is to lower
+#    each task within the Relay model into a tensor expression. The tensor
+#    expressions describe the operations, also known as functions, contained
+#    within a neural network. Once transformed into TE, further optimizations
+#    for the specific hardware target can be made. Work is underway to replace
+#    TE with a new representation, Tensor Intermediate Representation (TIR),
+#    that includes TE as a subset of TIR.
+#
+# 4. Search for an optimized schedule using *AutoTVM* or *AutoScheduler*.
+#    Tuning is defined as the process of searching for a schedule (an ordered
+#    notation) for the neural network to be compiled. There are a couple of
+#    optimization options available, each requiring varying levels of user
+#    interaction. Both of these methods can draw from the TVM Operator
+#    Inventory (TOPI). TOPI includes pre-defined templates of common machine
+#    learning operations. The optimization options include:
+#
+#    - **AutoTVM**: The user specifies a search template for the schedule of a TE task,
+#      or TE subgraph. AutoTVM directs the search of the parameter space defined by the
+#      template to produce an optimized configuration. AutoTVM requires users to
+#      manually define templates for each operator as part of the TOPI.
+#    - **Ansor/AutoScheduler**: Using a TVM Operator Inventory (TOPI) of operations,
+#      Ansor can automatically search an optimization space with much less
+#      intervention and guidance from the end user. Rather than hand-written
+#      templates, Ansor derives its search space from the TE compute definitions.
+#
+# 5. Determine the optimal schedule. After tuning, a schedule is determined to

Review comment:
       I would merge this into the previous step.
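Merging steps 4 and 5 would also read naturally in code. Picking the best schedule is just the last step of the search over the recorded measurements.

The sketch below assumes a deliberately simplified, hypothetical JSON-lines log with `config`, `costs` and `error` fields. Real AutoTVM and AutoScheduler logs use richer formats, and TVM provides its own utilities for reading them.

```python
import json
from statistics import mean
from typing import Any, Dict, List

# Hypothetical, simplified tuning log: one JSON object per measured config.
LOG = """\
{"config": {"tile": 8}, "costs": [0.0031, 0.0030], "error": 0}
{"config": {"tile": 32}, "costs": [0.0012, 0.0013], "error": 0}
{"config": {"tile": 64}, "costs": [], "error": 1}
{"config": {"tile": 16}, "costs": [0.0019, 0.0018], "error": 0}
"""


def best_config(log_text: str) -> Dict[str, Any]:
    """Return the config whose successful runs have the lowest mean cost."""
    records: List[Dict[str, Any]] = [
        json.loads(line) for line in log_text.splitlines() if line.strip()
    ]
    # Skip failed measurements (non-zero error code or no timings).
    valid = [r for r in records if r["error"] == 0 and r["costs"]]
    if not valid:
        raise ValueError("no successful measurements in the log")
    return min(valid, key=lambda r: mean(r["costs"]))["config"]


print(best_config(LOG))  # {'tile': 32}
```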

##########
File path: tutorials/get_started/introduction.py
##########
@@ -0,0 +1,137 @@
+# 5. Determine the optimal schedule. Whether AutoTVM or AutoScheduler is used,
+#    tuning produces schedule records in JSON format. The best schedule found is
+#    then chosen to determine how to optimize each layer of the neural network.
+#
+# 6. Lower to a hardware-specific compiler. TVM tuning operates by computing
+#    performance metrics for different operator configurations on the target
+#    hardware, then choosing the best configuration in the final code generation
+#    phase. This code generation is meant to produce an optimized model that can
+#    be deployed into production. TVM supports a number of different compiler
+#    backends, including:
+#
+#    - LLVM, which can target arbitrary microprocessor architectures, including
+#      standard x86 and ARM processors, AMDGPU and NVPTX code generation, and any
+#      other platform supported by LLVM.
+#    - Source-to-source compilation, such as with NVCC, NVIDIA's compiler.

Review comment:
       This isn't really source to source because TIR is not the source language.
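The reviewer's distinction can be made concrete. The CUDA path generates C-like source from TVM's own IR and then hands that generated text to NVIDIA's toolchain. The input to this step is therefore an IR, not user source code.

The toy emitter below shows the idea. It assumes a hypothetical one-loop IR given as a dictionary, and it is not TVM's code generator.

```python
from typing import Dict


def emit_c_loop(ir: Dict[str, str], func_name: str = "vector_add") -> str:
    """Print a one-loop IR node as C source text (toy code generator)."""
    var, extent, body = ir["var"], ir["extent"], ir["body"]
    lines = [
        f"void {func_name}(const float* A, const float* B, float* C, int {extent}) {{",
        f"  for (int {var} = 0; {var} < {extent}; ++{var}) {{",
        f"    {body}",
        "  }",
        "}",
    ]
    return "\n".join(lines)


# Hypothetical IR for C[i] = A[i] + B[i]; a real IR is a tree of nodes.
vector_add_ir = {"var": "i", "extent": "n", "body": "C[i] = A[i] + B[i];"}
source = emit_c_loop(vector_add_ir)
print(source)
```

An external compiler (NVCC for CUDA, or any C compiler for this toy output) would then compile `source`. The IR is consumed by the generation step, not by that compiler.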

##########
File path: tutorials/get_started/introduction.py
##########
@@ -0,0 +1,137 @@
+# 4. Search for an optimized schedule using *AutoTVM* or *AutoScheduler*.
+#    Tuning is defined as the process of searching for a schedule (an ordered
+#    notation) for the neural network to be compiled. There are a couple of
+#    optimization options available, each requiring varying levels of user
+#    interaction. Both of these methods can draw from the TVM Operator
+#    Inventory (TOPI). TOPI includes pre-defined templates of common machine

Review comment:
       I think mentions of TOPI should go above. TOPI is a collection of implementations of relay operators. The implementations may need to be tuned, in which case autotvm and auto_scheduler come into play.
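The reviewer's framing is that TOPI supplies operator implementations, and some of them expose knobs that AutoTVM or the auto-scheduler then tune. A toy tuner following that framing might look like this:

* `blocked_sum` stands in for an implementation with one tunable knob, a tile size.
* `tune` times each candidate value and keeps the fastest.

This exhaustive loop is closest in spirit to a grid search. TVM measures on the target device and offers smarter search strategies. None of these names are TVM APIs.

```python
import math
import time
from typing import Callable, List, Sequence, Tuple


def blocked_sum(data: Sequence[float], tile: int) -> float:
    """Hypothetical 'implementation' with one tunable knob: the tile size."""
    total = 0.0
    for start in range(0, len(data), tile):
        total += sum(data[start:start + tile])
    return total


def tune(
    impl: Callable[[Sequence[float], int], float],
    data: Sequence[float],
    candidates: List[int],
    repeat: int = 3,
) -> Tuple[int, float]:
    """Time every candidate knob value and return (best_tile, seconds_per_run)."""
    if repeat < 1 or not candidates:
        raise ValueError("need at least one candidate and one repeat")
    reference = sum(data)
    best_tile, best_time = candidates[0], math.inf
    for tile in candidates:
        start = time.perf_counter()
        for _ in range(repeat):
            result = impl(data, tile)
        elapsed = (time.perf_counter() - start) / repeat
        # Tuning may change the speed, never the answer.
        if not math.isclose(result, reference):
            raise ValueError(f"tile={tile} produced a wrong result")
        if elapsed < best_time:
            best_tile, best_time = tile, elapsed
    return best_tile, best_time


data = [float(x) for x in range(100_000)]
tile, seconds = tune(blocked_sum, data, candidates=[16, 256, 4096])
print(f"best tile: {tile} ({seconds * 1e3:.3f} ms per run)")
```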

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -37,124 +61,190 @@
 # Global declarations of environment.
 
 tgt_host = "llvm"
-# Change it to respective GPU if gpu is enabled Ex: cuda, opencl, rocm
-tgt = "cuda"
 
-######################################################################
-# Vector Add Example
-# ------------------
-# In this tutorial, we will use a vector addition example to demonstrate
-# the workflow.
-#
+# You will get better performance if you can identify the CPU you are targeting and specify it.
+# For example, ``tgt = "llvm -mcpu=broadwell"``
+tgt = "llvm"
 
 ######################################################################
-# Describe the Computation
-# ------------------------
-# As a first step, we need to describe our computation.
-# TVM adopts tensor semantics, with each intermediate result
-# represented as a multi-dimensional array. The user needs to describe
-# the computation rule that generates the tensors.
-#
-# We first define a symbolic variable n to represent the shape.
-# We then define two placeholder Tensors, A and B, with given shape (n,)
-#
-# We then describe the result tensor C, with a compute operation.  The
-# compute function takes the shape of the tensor, as well as a lambda
-# function that describes the computation rule for each position of
-# the tensor.
-#
-# No computation happens during this phase, as we are only declaring how
-# the computation should be done.
-#
+# Describing the Vector Computation
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# We describe a vector addition computation. TVM adopts tensor semantics, with
+# each intermediate result represented as a multi-dimensional array. The user
+# needs to describe the computation rule that generates the tensors. We first
+# define a symbolic variable n to represent the shape. We then define two
+# placeholder Tensors, A and B, with given shape (n,). We then describe the
+# result tensor C, with a compute operation. The compute function takes the
+# shape of the tensor, as well as a lambda function that describes the
+# computation rule for each position of the tensor. Note that while n is a
+# variable, it defines a consistent shape between the A, B and C tensors.
+# Remember, no actual computation happens during this phase, as we are only
+# declaring how the computation should be done.
+
 n = te.var("n")
 A = te.placeholder((n,), name="A")
 B = te.placeholder((n,), name="B")
 C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
 print(type(C))
 
 ######################################################################
-# Schedule the Computation
-# ------------------------
-# While the above lines describe the computation rule, we can compute
-# C in many ways since the axis of C can be computed in a data
-# parallel manner.  TVM asks the user to provide a description of the
-# computation called a schedule.
+# Create a Default Schedule for the Computation
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# While the above lines describe the computation rule, we can compute C in many
+# ways since the axis of C can be computed in a data parallel manner. TVM asks
+# the user to provide a description of the computation called a schedule.
 #
-# A schedule is a set of transformation of computation that transforms
-# the loop of computations in the program.
+# A schedule is a set of transformations that rearrange the loops of the
+# computation in the program.
 #
-# After we construct the schedule, by default the schedule computes
-# C in a serial manner in a row-major order.
+# By default, the schedule computes C serially, in row-major order.
 #
 # .. code-block:: c
 #
 #   for (int i = 0; i < n; ++i) {
 #     C[i] = A[i] + B[i];
 #   }
-#
+
 s = te.create_schedule(C.op)
 
 ######################################################################
-# We used the split construct to split the first axis of C,
-# this will split the original iteration axis into product of
-# two iterations. This is equivalent to the following code.
-#
-# .. code-block:: c
-#
-#   for (int bx = 0; bx < ceil(n / 64); ++bx) {
-#     for (int tx = 0; tx < 64; ++tx) {
-#       int i = bx * 64 + tx;
-#       if (i < n) {
-#         C[i] = A[i] + B[i];
-#       }
-#     }
-#   }
-#
+# Compile and Evaluate the Default Schedule
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# With the schedule created, we can now compile it down to our target language
+# and architecture, in this case LLVM to a CPU. We provide TVM with the basic
+# schedule, a list of the TE expressions that are in the schedule, the target
+# and host, and the name of the function we are producing. The result is a
+# type-erased function that can be called directly from Python.
+#
+# In the following line, we use tvm.build to create a function. The build

Review comment:
       I believe `tvm.build` is the old style interface. You want `relay.build`.

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -37,124 +61,190 @@
 # Global declarations of environment.
 
 tgt_host = "llvm"
-# Change it to respective GPU if gpu is enabled Ex: cuda, opencl, rocm
-tgt = "cuda"
 
-######################################################################
-# Vector Add Example
-# ------------------
-# In this tutorial, we will use a vector addition example to demonstrate
-# the workflow.
-#
+# You will get better performance if you can identify the CPU you are targeting and specify it.
+# For example, ``tgt = "llvm -mcpu=broadwell"``
+tgt = "llvm"
 
 ######################################################################
-# Describe the Computation
-# ------------------------
-# As a first step, we need to describe our computation.
-# TVM adopts tensor semantics, with each intermediate result
-# represented as a multi-dimensional array. The user needs to describe
-# the computation rule that generates the tensors.
-#
-# We first define a symbolic variable n to represent the shape.
-# We then define two placeholder Tensors, A and B, with given shape (n,)
-#
-# We then describe the result tensor C, with a compute operation.  The
-# compute function takes the shape of the tensor, as well as a lambda
-# function that describes the computation rule for each position of
-# the tensor.
-#
-# No computation happens during this phase, as we are only declaring how
-# the computation should be done.
-#
+# Describing the Vector Computation
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# We describe a vector addition computation. TVM adopts tensor semantics, with
+# each intermediate result represented as a multi-dimensional array. The user
+# needs to describe the computation rule that generates the tensors. We first
+# define a symbolic variable n to represent the shape. We then define two
+# placeholder Tensors, A and B, with given shape (n,). We then describe the
+# result tensor C, with a compute operation. The compute function takes the
+# shape of the tensor, as well as a lambda function that describes the
+# computation rule for each position of the tensor. Note that while n is a
+# variable, it defines a consistent shape between the A, B and C tensors.
+# Remember, no actual computation happens during this phase, as we are only
+# declaring how the computation should be done.
+
 n = te.var("n")
 A = te.placeholder((n,), name="A")
 B = te.placeholder((n,), name="B")
 C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
 print(type(C))
 
 ######################################################################
-# Schedule the Computation
-# ------------------------
-# While the above lines describe the computation rule, we can compute
-# C in many ways since the axis of C can be computed in a data
-# parallel manner.  TVM asks the user to provide a description of the
-# computation called a schedule.
+# Create a Default Schedule for the Computation
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# While the above lines describe the computation rule, we can compute C in many
+# ways since the axis of C can be computed in a data parallel manner. TVM asks
+# the user to provide a description of the computation called a schedule.
 #
-# A schedule is a set of transformation of computation that transforms
-# the loop of computations in the program.
+# A schedule is a set of transformations that rewrite the loops of the
+# computation in the program.
 #
-# After we construct the schedule, by default the schedule computes
-# C in a serial manner in a row-major order.
+# By default, TVM will compute a schedule for C in a serial manner using
+# row-major order.
 #
 # .. code-block:: c
 #
 #   for (int i = 0; i < n; ++i) {
 #     C[i] = A[i] + B[i];
 #   }
-#
+
 s = te.create_schedule(C.op)
 
 ######################################################################
-# We used the split construct to split the first axis of C,
-# this will split the original iteration axis into product of
-# two iterations. This is equivalent to the following code.
-#
-# .. code-block:: c
-#
-#   for (int bx = 0; bx < ceil(n / 64); ++bx) {
-#     for (int tx = 0; tx < 64; ++tx) {
-#       int i = bx * 64 + tx;
-#       if (i < n) {
-#         C[i] = A[i] + B[i];
-#       }
-#     }
-#   }
-#
+# Compile and Evaluate the Default Schedule
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# With the schedule created, we can now compile it down to our target language
+# and architecture, in this case LLVM to a CPU. We provide TVM with the basic
+# schedule, a list of the TE expressions that are in the schedule, the target
+# and host, and the name of the function we are producing. The result is a
+# type-erased function that can be called directly from Python.
+#
+# In the following line, we use tvm.build to create a function. The build
+# function takes the schedule, the desired signature of the function (including
+# the inputs and outputs) as well as target language we want to compile to.
+
+fadd = tvm.build(s, [A, B, C], tgt, target_host=tgt_host, name="myadd")
+
+################################################################################
+# Let's run the function, and compare the output to the same computation in
+# numpy. We begin by creating a context, which is a runtime that TVM can

Review comment:
       ```suggestion
   # numpy. We begin by creating a context, which is a device (CPU, GPU) that TVM can
   ```

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -37,124 +61,190 @@
 # Global declarations of environment.
 
 tgt_host = "llvm"
-# Change it to respective GPU if gpu is enabled Ex: cuda, opencl, rocm
-tgt = "cuda"
 
-######################################################################
-# Vector Add Example
-# ------------------
-# In this tutorial, we will use a vector addition example to demonstrate
-# the workflow.
-#
+# You will get better performance if you can identify the CPU you are targeting and specify it.
+# For example, ``tgt = "llvm -mcpu=broadwell"``
+tgt = "llvm"
 
 ######################################################################
-# Describe the Computation
-# ------------------------
-# As a first step, we need to describe our computation.
-# TVM adopts tensor semantics, with each intermediate result
-# represented as a multi-dimensional array. The user needs to describe
-# the computation rule that generates the tensors.
-#
-# We first define a symbolic variable n to represent the shape.
-# We then define two placeholder Tensors, A and B, with given shape (n,)
-#
-# We then describe the result tensor C, with a compute operation.  The
-# compute function takes the shape of the tensor, as well as a lambda
-# function that describes the computation rule for each position of
-# the tensor.
-#
-# No computation happens during this phase, as we are only declaring how
-# the computation should be done.
-#
+# Describing the Vector Computation
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# We describe a vector addition computation. TVM adopts tensor semantics, with
+# each intermediate result represented as a multi-dimensional array. The user
+# needs to describe the computation rule that generates the tensors. We first
+# define a symbolic variable n to represent the shape. We then define two
+# placeholder Tensors, A and B, with given shape (n,). We then describe the
+# result tensor C, with a compute operation. The compute function takes the
+# shape of the tensor, as well as a lambda function that describes the
+# computation rule for each position of the tensor. Note that while n is a
+# variable, it defines a consistent shape between the A, B and C tensors.
+# Remember, no actual computation happens during this phase, as we are only
+# declaring how the computation should be done.
+
 n = te.var("n")
 A = te.placeholder((n,), name="A")
 B = te.placeholder((n,), name="B")
 C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
 print(type(C))
 
 ######################################################################
-# Schedule the Computation
-# ------------------------
-# While the above lines describe the computation rule, we can compute
-# C in many ways since the axis of C can be computed in a data
-# parallel manner.  TVM asks the user to provide a description of the
-# computation called a schedule.
+# Create a Default Schedule for the Computation
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# While the above lines describe the computation rule, we can compute C in many
+# ways since the axis of C can be computed in a data parallel manner. TVM asks
+# the user to provide a description of the computation called a schedule.
 #
-# A schedule is a set of transformation of computation that transforms
-# the loop of computations in the program.
+# A schedule is a set of transformations that rewrite the loops of the
+# computation in the program.
 #
-# After we construct the schedule, by default the schedule computes
-# C in a serial manner in a row-major order.
+# By default, TVM will compute a schedule for C in a serial manner using
+# row-major order.

Review comment:
       ```suggestion
   # The default schedule will compute C by iterating in row-major order.
   ```
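The hunks above remove the tutorial's older explanation of ``split``, which rewrote the serial row-major loop into an outer loop over blocks of 64 and an inner loop over elements, guarded by ``i < n``. As a plain-Python sketch (not TVM's generated code; the helper names and the factor of 64 are illustrative only), the split loop nest visits exactly the same indices, in the same order, as the default schedule:

```python
import math


def default_order(n):
    # Serial, row-major iteration: for (int i = 0; i < n; ++i)
    return list(range(n))


def split_order(n, factor=64):
    # The single axis i is split into an outer axis bx and an inner axis tx,
    # with i = bx * factor + tx. The `if i < n` guard mirrors the bounds
    # check in the C sketch and matters when n is not a multiple of factor.
    order = []
    for bx in range(math.ceil(n / factor)):
        for tx in range(factor):
            i = bx * factor + tx
            if i < n:
                order.append(i)
    return order


def vector_add(a, b, order):
    # C[i] = A[i] + B[i], visiting indices in the given order.
    c = [0.0] * len(a)
    for i in order:
        c[i] = a[i] + b[i]
    return c


n = 1000  # deliberately not a multiple of 64
a = [float(i) for i in range(n)]
b = [2.0 * i for i in range(n)]
assert split_order(n) == default_order(n)
assert vector_add(a, b, split_order(n)) == vector_add(a, b, default_order(n))
```

The split by itself changes only the loop structure, not the result; it becomes useful once the new axes are parallelized, vectorized, or bound to GPU threads.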

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -37,124 +61,190 @@
 # Global declarations of environment.
 
 tgt_host = "llvm"
-# Change it to respective GPU if gpu is enabled Ex: cuda, opencl, rocm
-tgt = "cuda"
 
-######################################################################
-# Vector Add Example
-# ------------------
-# In this tutorial, we will use a vector addition example to demonstrate
-# the workflow.
-#
+# You will get better performance if you can identify the CPU you are targeting and specify it.
+# For example, ``tgt = "llvm -mcpu=broadwell"``
+tgt = "llvm"
 
 ######################################################################
-# Describe the Computation
-# ------------------------
-# As a first step, we need to describe our computation.
-# TVM adopts tensor semantics, with each intermediate result
-# represented as a multi-dimensional array. The user needs to describe
-# the computation rule that generates the tensors.
-#
-# We first define a symbolic variable n to represent the shape.
-# We then define two placeholder Tensors, A and B, with given shape (n,)
-#
-# We then describe the result tensor C, with a compute operation.  The
-# compute function takes the shape of the tensor, as well as a lambda
-# function that describes the computation rule for each position of
-# the tensor.
-#
-# No computation happens during this phase, as we are only declaring how
-# the computation should be done.
-#
+# Describing the Vector Computation
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# We describe a vector addition computation. TVM adopts tensor semantics, with
+# each intermediate result represented as a multi-dimensional array. The user
+# needs to describe the computation rule that generates the tensors. We first
+# define a symbolic variable n to represent the shape. We then define two
+# placeholder Tensors, A and B, with given shape (n,). We then describe the
+# result tensor C, with a compute operation. The compute function takes the
+# shape of the tensor, as well as a lambda function that describes the
+# computation rule for each position of the tensor. Note that while n is a
+# variable, it defines a consistent shape between the A, B and C tensors.
+# Remember, no actual computation happens during this phase, as we are only
+# declaring how the computation should be done.
+
 n = te.var("n")
 A = te.placeholder((n,), name="A")
 B = te.placeholder((n,), name="B")
 C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
 print(type(C))
 
 ######################################################################
-# Schedule the Computation
-# ------------------------
-# While the above lines describe the computation rule, we can compute
-# C in many ways since the axis of C can be computed in a data
-# parallel manner.  TVM asks the user to provide a description of the
-# computation called a schedule.
+# Create a Default Schedule for the Computation
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# While the above lines describe the computation rule, we can compute C in many
+# ways since the axis of C can be computed in a data parallel manner. TVM asks
+# the user to provide a description of the computation called a schedule.
 #
-# A schedule is a set of transformation of computation that transforms
-# the loop of computations in the program.
+# A schedule is a set of transformations that rewrite the loops of the
+# computation in the program.
 #
-# After we construct the schedule, by default the schedule computes
-# C in a serial manner in a row-major order.
+# By default, TVM will compute a schedule for C in a serial manner using
+# row-major order.
 #
 # .. code-block:: c
 #
 #   for (int i = 0; i < n; ++i) {
 #     C[i] = A[i] + B[i];
 #   }
-#
+
 s = te.create_schedule(C.op)
 
 ######################################################################
-# We used the split construct to split the first axis of C,
-# this will split the original iteration axis into product of
-# two iterations. This is equivalent to the following code.
-#
-# .. code-block:: c
-#
-#   for (int bx = 0; bx < ceil(n / 64); ++bx) {
-#     for (int tx = 0; tx < 64; ++tx) {
-#       int i = bx * 64 + tx;
-#       if (i < n) {
-#         C[i] = A[i] + B[i];
-#       }
-#     }
-#   }
-#
+# Compile and Evaluate the Default Schedule
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# With the schedule created, we can now compile it down to our target language
+# and architecture, in this case LLVM to a CPU. We provide TVM with the basic
+# schedule, a list of the TE expressions that are in the schedule, the target
+# and host, and the name of the function we are producing. The result is a
+# type-erased function that can be called directly from Python.
+#
+# In the following line, we use tvm.build to create a function. The build
+# function takes the schedule, the desired signature of the function (including
+# the inputs and outputs) as well as target language we want to compile to.
+
+fadd = tvm.build(s, [A, B, C], tgt, target_host=tgt_host, name="myadd")
+
+################################################################################
+# Let's run the function, and compare the output to the same computation in
+# numpy. We begin by creating a context, which is a runtime that TVM can
+# compile the schedule to. In this case the context is an LLVM CPU target. We
+# can then initialize the tensors in our context and perform the custom
+# addition operation. To verify that the computation is correct, we can
+# compare the contents of the output tensor c to the same computation
+# performed by numpy.
+
+ctx = tvm.context(tgt, 0)
+
+n = 1024
+a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
+b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
+c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
+fadd(a, b, c)
+tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
+
+################################################################################
+# Updating the Schedule to Use Paralleism
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# Now that we've illustrated the fundamentals of TE, let's go deeper into what
+# schedules do, and how they can be used to optimize tensor expressions for
+# different architectures. A schedule is a series of steps that are applied to
+# a TE to transform it in a number of different ways. When a schedule is

Review comment:
       I'm not sure how people refer to a single TE expression. I use "TE expression" to differentiate it from "TE" which is the language.

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -223,22 +295,21 @@
 cc.create_shared(temp.relpath("myadd.so"), [temp.relpath("myadd.o")])
 print(temp.listdir())
 
-######################################################################
+################################################################################
 # .. note:: Module Storage Format
 #
-#   The CPU (host) module is directly saved as a shared library (.so).
-#   There can be multiple customized formats of the device code.
-#   In our example, the device code is stored in ptx, as well as a meta
-#   data json file. They can be loaded and linked separately via import.
-#
+#   The CPU (host) module is directly saved as a shared library (.so). There

Review comment:
       Is everything about saving and loading code relevant to this tutorial? Should there be a separate tutorial about packaging?

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -163,52 +253,34 @@
 fadd(a, b, c)
 tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 
-######################################################################
-# Inspect the Generated Code
-# --------------------------
-# You can inspect the generated code in TVM. The result of tvm.build
-# is a TVM Module. fadd is the host module that contains the host wrapper,
-# it also contains a device module for the CUDA (GPU) function.
+################################################################################
+# Inspect the Generated GPU Code
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# You can inspect the generated code in TVM. The result of tvm.build is a TVM
+# Module. fadd is the host module that contains the host wrapper; it also
+# contains a device module for the CUDA (GPU) function.
 #
 # The following code fetches the device module and prints the content code.
-#
-if tgt == "cuda" or tgt == "rocm" or tgt.startswith("opencl"):
+
+if tgt_gpu == "cuda" or tgt_gpu == "rocm" or tgt_gpu.startswith("opencl"):
     dev_module = fadd.imported_modules[0]
     print("-----GPU code-----")
     print(dev_module.get_source())
 else:
     print(fadd.get_source())
 
-######################################################################
-# .. note:: Code Specialization
-#
-#   As you may have noticed, the declarations of A, B and C all
-#   take the same shape argument, n. TVM will take advantage of this
-#   to pass only a single shape argument to the kernel, as you will find in
-#   the printed device code. This is one form of specialization.
-#
-#   On the host side, TVM will automatically generate check code
-#   that checks the constraints in the parameters. So if you pass
-#   arrays with different shapes into fadd, an error will be raised.
-#
-#   We can do more specializations. For example, we can write
-#   :code:`n = tvm.runtime.convert(1024)` instead of :code:`n = te.var("n")`,
-#   in the computation declaration. The generated function will
-#   only take vectors with length 1024.
-#
-
-######################################################################
-# Save Compiled Module
-# --------------------
-# Besides runtime compilation, we can save the compiled modules into
-# a file and load them back later. This is called ahead of time compilation.
+################################################################################
+# Saving and Loading Compiled Modules
+# -----------------------------------
+# Besides runtime compilation, we can save the compiled modules into a file and
+# load them back later. This is called ahead of time compilation.

Review comment:
       ```suggestion
   # load them back later.
   ```

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -302,18 +371,437 @@
     fadd_cl(a, b, c)
     tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 
-######################################################################
-# Summary
-# -------
-# This tutorial provides a walk through of TVM workflow using
-# a vector add example. The general workflow is
+################################################################################
+# .. note:: Code Specialization
+#
+#   As you may have noticed, the declarations of A, B and C all take the same
+#   shape argument, n. TVM will take advantage of this to pass only a single
+#   shape argument to the kernel, as you will find in the printed device code.
+#   This is one form of specialization.
+#
+#   On the host side, TVM will automatically generate check code that checks
+#   the constraints in the parameters. So if you pass arrays with different
+#   shapes into fadd, an error will be raised.
+#
+#   We can do more specializations. For example, we can write :code:`n =
+#   tvm.runtime.convert(1024)` instead of :code:`n = te.var("n")`, in the
+#   computation declaration. The generated function will only take vectors with
+#   length 1024.
+
+################################################################################
+# .. note:: TE Scheduling Primitives
+#
+#   TVM includes a number of different scheduling primitives:
+#
+#   - split: splits a specified axis into two axes by the defined factor.
+#   - tile: tiles a computation across two axes by the defined factors.
+#   - fuse: fuses two consecutive axes of one computation.
+#   - reorder: reorders the axes of a computation into a defined order.
+#   - bind: binds a computation to a specific thread, useful in GPU programming.
+#   - compute_at: by default, TVM computes tensors at the root. compute_at specifies

Review comment:
       What does it mean to compute something at the root?
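One way to read "compute at the root" for the question above: by default each stage gets its own top-level loop nest, so a producer tensor is fully computed into a buffer before its consumer starts. ``compute_at`` instead moves the producer's computation inside one of the consumer's loops. The plain-Python analogy below uses a hypothetical two-stage pipeline (B = A + 1, C = B * 2) and is not TVM's lowered IR; it only shows that both placements give the same result while changing where, and how much of, B is materialized:

```python
def pipeline_compute_root(a):
    # Default ("root") placement: B is computed by its own top-level loop
    # and fully materialized before C's loop runs.
    n = len(a)
    b = [0] * n  # whole intermediate buffer for B
    for i in range(n):
        b[i] = a[i] + 1
    c = [0] * n
    for i in range(n):
        c[i] = b[i] * 2
    return c


def pipeline_compute_at(a):
    # compute_at-style placement: B's computation is moved inside C's loop
    # over i, so each element of B is produced right before it is consumed
    # and only a scalar temporary is needed.
    n = len(a)
    c = [0] * n
    for i in range(n):
        b_i = a[i] + 1  # B computed at C's axis i
        c[i] = b_i * 2
    return c


a = list(range(8))
assert pipeline_compute_root(a) == pipeline_compute_at(a)
```

In this reading, ``compute_root`` simply moves a stage back to its own top-level loop nest, as in the first function.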

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -302,18 +371,437 @@
     fadd_cl(a, b, c)
     tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 
-######################################################################
-# Summary
-# -------
-# This tutorial provides a walk through of TVM workflow using
-# a vector add example. The general workflow is
+################################################################################
+# .. note:: Code Specialization
+#
+#   As you may have noticed, the declarations of A, B and C all take the same
+#   shape argument, n. TVM will take advantage of this to pass only a single
+#   shape argument to the kernel, as you will find in the printed device code.
+#   This is one form of specialization.
+#
+#   On the host side, TVM will automatically generate check code that checks
+#   the constraints in the parameters. So if you pass arrays with different
+#   shapes into fadd, an error will be raised.
+#
+#   We can do more specializations. For example, we can write :code:`n =
+#   tvm.runtime.convert(1024)` instead of :code:`n = te.var("n")`, in the
+#   computation declaration. The generated function will only take vectors with
+#   length 1024.
+
+################################################################################
+# .. note:: TE Scheduling Primitives
+#
+#   TVM includes a number of different scheduling primitives:
+#
+#   - split: splits a specified axis into two axes by the defined factor.
+#   - tile: tiles a computation across two axes by the defined factors.
+#   - fuse: fuses two consecutive axes of one computation.
+#   - reorder: reorders the axes of a computation into a defined order.
+#   - bind: binds a computation to a specific thread, useful in GPU programming.
+#   - compute_at: by default, TVM computes tensors at the root. compute_at specifies
+#     that one tensor should be computed at a given axis of another operator's computation.
+#   - compute_inline: when marked inline, a computation will be expanded and then inserted
+#     at the point where the tensor is required.
+#   - compute_root: moves a computation to the root stage.
+#
+#   A complete description of these primitives can be found in the
+#   `Schedule Primitives <https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html>`_ docs page.
+
+################################################################################
+# Example 2: Manually Optimizing Matrix Multiplication with TE
+# ------------------------------------------------------------
+#
+# Now we will consider a second, more advanced example, demonstrating how,
+# with just 18 lines of Python code, TVM can deliver up to an 18x speedup on
+# a common matrix multiplication operation.
+#
+# **There are two important optimizations on intense computation applications
+# executed on CPU:**

Review comment:
       ```suggestion
   # **Matrix multiplication is a compute intensive operation. There are two important optimizations for good CPU performance:**
   ```
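The first optimization in the hunk above (raising the cache hit rate by changing the memory access pattern) is commonly achieved by blocking, or tiling, the loops. As a minimal plain-Python sketch (pure Python will not show the speedup itself; the block size of 4 and the function names are illustrative assumptions, not values from the tutorial), a blocked loop nest computes the same product as the naive triple loop while working on small tiles that would fit in cache:

```python
import random


def matmul_naive(A, B):
    # C[i][j] = sum_k A[i][k] * B[k][j], plain i-j-k loop nest.
    M, K, N = len(A), len(B), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for i in range(M):
        for j in range(N):
            for k in range(K):
                C[i][j] += A[i][k] * B[k][j]
    return C


def matmul_blocked(A, B, bs=4):
    # Same computation, but the i, j and k axes are split into blocks of
    # size bs, so each inner loop nest touches at most a bs x bs tile of
    # A, B and C at a time (the idea behind tiling for cache locality).
    M, K, N = len(A), len(B), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for i0 in range(0, M, bs):
        for j0 in range(0, N, bs):
            for k0 in range(0, K, bs):
                for i in range(i0, min(i0 + bs, M)):
                    for k in range(k0, min(k0 + bs, K)):
                        a_ik = A[i][k]
                        for j in range(j0, min(j0 + bs, N)):
                            C[i][j] += a_ik * B[k][j]
    return C


random.seed(0)
A = [[random.randint(-5, 5) for _ in range(7)] for _ in range(9)]  # 9 x 7
B = [[random.randint(-5, 5) for _ in range(6)] for _ in range(7)]  # 7 x 6
assert matmul_blocked(A, B) == matmul_naive(A, B)
```

Integer inputs keep the comparison exact; with real-valued data, a tolerance-based check is the safer choice if a schedule changes the summation order.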

##########
File path: tutorials/get_started/introduction.py
##########
@@ -0,0 +1,137 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Introduction
+============
+**Authors**:
+`Jocelyn Shiue <https://github.com/>`_,
+`Chris Hoge <https://github.com/hogepodge>`_
+
+Apache TVM is an open source machine learning compiler framework for CPUs,
+GPUs, and machine learning accelerators. It aims to enable machine learning
+engineers to optimize and run computations efficiently on any hardware backend.
+The purpose of this tutorial is to take a guided tour through all of the major
+features of TVM by defining and demonstrating key concepts. A new user should
+be able to work through the tutorial from start to finish and be able to
+operate TVM for automatic model optimization, while having a basic
+understanding of the TVM architecture and how it works.
+
+Contents
+--------
+
+#. :doc:`Introduction <introduction>`
+#. :doc:`Installing TVM <install>`
+#. :doc:`Compiling and Optimizing a Model with TVMC <tvmc_command_line_driver>`
+#. :doc:`Compiling and Optimizing a Model with the Python AutoScheduler <auto_tuning_with_python>`
+#. :doc:`Working with Operators Using Tensor Expressions <tensor_expr_get_started>`
+#. :doc:`Optimizing Operators with Templates and AutoTVM <autotvm_matmul>`
+#. :doc:`Optimizing Operators with AutoScheduling <tune_matmul_x86>`
+#. :doc:`Cross Compilation and Remote Procedure Calls (RPC) <cross_compilation_and_rpc>`
+#. :doc:`Compiling Deep Learning Models for GPUs <relay_quick_start>`
+"""
+
+################################################################################
+# An Overview of TVM and Model Optimization
+# =========================================
+#
+# The diagram below illustrates the steps a machine learning model takes as it is
+# transformed with the TVM optimizing compiler framework.
+#
+# .. image:: /_static/img/tvm.png
+#   :width: 100%
+#   :alt: A High Level View of TVM
+#
+# 1. Import the model from a framework like *TensorFlow*, *PyTorch*, or *ONNX*.
+#    The importer layer is where TVM can ingest models from other frameworks, like
+#    ONNX, TensorFlow, or PyTorch. The level of support that TVM offers for each
+#    frontend varies as we are constantly improving the open source project. If
+#    you're having issues importing your model into TVM, you may want to try
+#    converting it to ONNX.
+#
+# 2. Translate to *Relay,* TVM's high level model language.
+#    A model that has been imported into TVM is represented in Relay. Relay is a
+#    functional language and intermediate representation (IR) for neural networks.
+#    It has support for:
+#
+#    - Traditional data flow-style representations
+#    - Functional-style scoping, let-binding which makes it a fully featured
+#      differentiable language
+#    - Ability to allow the user to mix the two programming styles
+#
+#    Relay (more specifically, its fusion pass) is in charge of splitting the
+#    neural network into small subgraphs, each of which is a task.
+#
+# 3. Lower to *TE*, tensor expressions that define the *computational
+#    operations* of the neural network.
+#    Upon completing the import and high level optimizations, the next step is
+#    to decide how to implement the Relay representation on a hardware target.
+#    Here, lowering means lowering into TE tasks. The first step is to lower
+#    each task within the Relay model into a tensor expression. The tensor
+#    expressions describe the operations, aka functions, contained within a
+#    neural network. Once transformed into TE, further optimizations for the
+#    specific hardware target can be made. Work is underway to replace TR with

Review comment:
       ```suggestion
   #    specific hardware target can be made. Work is underway to replace TE with
   ```

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -302,18 +371,437 @@
     fadd_cl(a, b, c)
     tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 
-######################################################################
-# Summary
-# -------
-# This tutorial provides a walk through of TVM workflow using
-# a vector add example. The general workflow is
+################################################################################
+# .. note:: Code Specialization
+#
+#   As you may have noticed, the declarations of A, B and C all take the same
+#   shape argument, n. TVM will take advantage of this to pass only a single
+#   shape argument to the kernel, as you will find in the printed device code.
+#   This is one form of specialization.
+#
+#   On the host side, TVM will automatically generate check code that checks
+#   the constraints in the parameters. So if you pass arrays with different
+#   shapes into fadd, an error will be raised.
+#
+#   We can do more specializations. For example, we can write :code:`n =
+#   tvm.runtime.convert(1024)` instead of :code:`n = te.var("n")`, in the
+#   computation declaration. The generated function will only take vectors with
+#   length 1024.
+
+################################################################################
+# .. note:: TE Scheduling Primitives
+#
+#   TVM includes a number of different scheduling primitives:
+#
+#   - split: splits a specified axis into two axes by the defined factor.
+#   - tile: splits two axes of a computation by the defined factors, producing a tiled loop nest.
+#   - fuse: fuses two consecutive axes of one computation.
+#   - reorder: reorders the axes of a computation into a defined order.
+#   - bind: binds a computation to a specific thread, which is useful in GPU programming.
+#   - compute_at: by default, TVM computes tensors at the root. compute_at specifies
+#     that one tensor should be computed at an axis of another operator's computation.
+#   - compute_inline: when marked inline, a computation will be expanded and then inserted
+#     at the place where the tensor is required.
+#   - compute_root: moves a computation to the root stage.
+#
+#   A complete description of these primitives can be found in the
+#   `Schedule Primitives <https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html>`_ docs page.
+
+################################################################################
+# Example 2: Manually Optimizing Matrix Multiplication with TE
+# ------------------------------------------------------------
+#
+# Now we will consider a second, more advanced example, demonstrating how,
+# with just 18 lines of Python code, TVM can achieve up to an 18x speedup on
+# a common matrix multiplication operation.
+#
+# **There are two important optimizations for compute-intensive applications
+# executed on CPU:**
+#
+# 1. Increase the cache hit rate of memory access. Both complex numerical
+#    computation and hot-spot memory access can be accelerated by a high cache
+#    hit rate. This requires us to transform the original memory access
+#    pattern into a pattern that fits the cache policy.
+# 2. SIMD (Single Instruction, Multiple Data), also known as the vector
+#    processing unit. Each time, a small batch of data, rather than a single
+#    element, is processed. This requires us to transform the data access
+#    pattern in the loop body into a uniform pattern so that the LLVM backend
+#    can lower it to SIMD.
+#
+# The techniques used in this tutorial are a subset of the tricks mentioned in
+# this `repository <https://github.com/flame/how-to-optimize-gemm>`_. Some of
+# them are applied automatically by the TVM abstraction, but some of them
+# cannot be applied directly due to TVM constraints.
+#
+# All the experimental results mentioned below were collected on a 2015 15"
+# MacBook equipped with an Intel i7-4770HQ CPU. The cache line size should be
+# 64 bytes for all x86 CPUs.
+
+################################################################################
+# Preparation and Performance Baseline
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# We begin by collecting performance data on the ``numpy`` implementation of
+# matrix multiplication.
+
+import tvm
+import tvm.testing
+from tvm import te
+import numpy
+import timeit
+
+# The size of the matrix
+# (M, K) x (K, N)
+# You are free to try out different shapes; sometimes TVM optimization outperforms numpy with MKL.
+M = 1024
+K = 1024
+N = 1024
+
+# The default tensor data type in tvm
+dtype = "float32"
+
+# To use Intel AVX2 (Advanced Vector Extensions) for SIMD and get the best
+# performance, change the following line to llvm -mcpu=core-avx2, or to the
+# specific type of CPU you use.
+target = "llvm"
+ctx = tvm.context(target, 0)
+
+# Randomly generated tensors for testing
+a = tvm.nd.array(numpy.random.rand(M, K).astype(dtype), ctx)
+b = tvm.nd.array(numpy.random.rand(K, N).astype(dtype), ctx)
+
+# Repeatedly perform a matrix multiplication to get a performance baseline
+# for the default numpy implementation
+np_repeat = 100
+np_running_time = timeit.timeit(
+    setup="import numpy\n"
+    "M = " + str(M) + "\n"
+    "K = " + str(K) + "\n"
+    "N = " + str(N) + "\n"
+    'dtype = "float32"\n'
+    "a = numpy.random.rand(M, K).astype(dtype)\n"
+    "b = numpy.random.rand(K, N).astype(dtype)\n",
+    stmt="answer = numpy.dot(a, b)",
+    number=np_repeat,
+)
+print("Numpy running time: %f" % (np_running_time / np_repeat))
+
+answer = numpy.dot(a.asnumpy(), b.asnumpy())
+
+################################################################################
+# Now, write a basic matrix multiplication using TVM TE and verify that it

Review comment:
       ```suggestion
   # Now we write a basic matrix multiplication using TVM TE and verify that it
   ```
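
The tile primitive in the quoted note (splitting two axes and reordering them into outer and inner loops) and the cache-hit-rate optimization described in the same hunk can be illustrated without TVM at all. The following is a minimal pure-Python sketch, not TVM's API or its generated code: it computes a small (M, K) x (K, N) product with the naive i, j, k loop nest, and again with i and j split by a block factor `bn` and reordered to (io, jo, k, ii, ji), roughly the loop structure that `s[C].tile(...)` followed by a reorder describes in TE. The helper names, the block factor, and this particular loop order are hypothetical choices for illustration. Pure Python will not show the speedup; the point is only that the schedule changes the loop structure, not the result.

```python
import random


def matmul_naive(a, b):
    """Reference (M, K) x (K, N) product with the default i, j, k loop order."""
    m, k, n = len(a), len(b), len(b[0])
    c = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            for kk in range(k):
                c[i][j] += a[i][kk] * b[kk][j]
    return c


def matmul_tiled(a, b, bn=4):
    """Same product with i and j split by bn and reordered to (io, jo, k, ii, ji).

    Each bn x bn block of C is finished before moving on, so the rows of A and
    the columns of B touched by the inner loops stay small and cache-friendly.
    """
    m, k, n = len(a), len(b), len(b[0])
    c = [[0.0] * n for _ in range(m)]
    for io in range(0, m, bn):                            # outer part of split i
        for jo in range(0, n, bn):                        # outer part of split j
            for kk in range(k):                           # reduction axis
                for i in range(io, min(io + bn, m)):      # inner part of i (tail-guarded)
                    for j in range(jo, min(jo + bn, n)):  # inner part of j (tail-guarded)
                        c[i][j] += a[i][kk] * b[kk][j]
    return c


# Usage: shapes that are not multiples of bn exercise the tail of each split.
random.seed(0)
A = [[random.random() for _ in range(7)] for _ in range(10)]  # (M, K) = (10, 7)
B = [[random.random() for _ in range(9)] for _ in range(7)]   # (K, N) = (7, 9)
assert matmul_tiled(A, B, bn=4) == matmul_naive(A, B)
```

Because every output element still accumulates its k terms in the same order, the tiled result is bit-for-bit identical to the naive one, which is why the example can compare with `==` rather than a tolerance.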

##########
File path: tutorials/get_started/introduction.py
##########
@@ -0,0 +1,137 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Introduction
+============
+**Authors**:
+`Jocelyn Shiue <https://github.com/>`_,
+`Chris Hoge <https://github.com/hogepodge>`_
+
+Apache TVM is an open source machine learning compiler framework for CPUs,
+GPUs, and machine learning accelerators. It aims to enable machine learning
+engineers to optimize and run computations efficiently on any hardware backend.
+The purpose of this tutorial is to take a guided tour through all of the major
+features of TVM by defining and demonstrating key concepts. A new user should
+be able to work through the tutorial from start to finish and be able to
+operate TVM for automatic model optimization, while having a basic
+understanding of the TVM architecture and how it works.
+
+Contents
+--------
+
+#. :doc:`Introduction <introduction>`
+#. :doc:`Installing TVM <install>`
+#. :doc:`Compiling and Optimizing a Model with TVMC <tvmc_command_line_driver>`
+#. :doc:`Compiling and Optimizing a Model with the Python AutoScheduler <auto_tuning_with_python>`
+#. :doc:`Working with Operators Using Tensor Expressions <tensor_expr_get_started>`
+#. :doc:`Optimizing Operators with Templates and AutoTVM <autotvm_matmul>`
+#. :doc:`Optimizing Operators with AutoScheduling <tune_matmul_x86>`
+#. :doc:`Cross Compilation and Remote Procedure Calls (RPC) <cross_compilation_and_rpc>`
+#. :doc:`Compiling Deep Learning Models for GPUs <relay_quick_start>`
+"""
+
+################################################################################
+# An Overview of TVM and Model Optimization
+# =========================================
+#
+# The diagram below illustrates the steps a machine model takes as it is
+# transformed with the TVM optimizing compiler framework.
+#
+# .. image:: /_static/img/tvm.png
+#   :width: 100%
+#   :alt: A High Level View of TVM
+#
+# 1. Import the model from a framework like *Tensorflow*, *Pytorch*, or *Onnx*.
+#    The importer layer is where TVM can ingest models from other frameworks, like
+#    ONNX, Tensorflow, or PyTorch. The level of support that TVM offers for each
+#    frontend varies as we are constantly improving the open source project. If
+#    you're having issues importing your model into TVM, you may want to try
+#    converting it to ONNX.
+#
+# 2. Translate to *Relay,* TVM's high level model language.
+#    A model that has been imported into TVM is represented in Relay. Relay is a
+#    functional language and intermediate representation (IR) for neural networks.
+#    It has support for:
+#
+#    - Traditional data flow-style representations
+#    - Functional-style scoping, let-binding which makes it a fully featured
+#      differentiable language
+#    - Ability to allow the user to mix the two programming styles
+#
+#    Relay (or more detailedly, its fusion pass) is in charge of splitting the
+#    neural network into small subgraphs, each of which is a task.
+#
+# 3. Lower to *TE*, tensor expressions that define the *computational
+#    operations* of the neural network.
+#    Upon completing the import and high level optimizations, the next step it
+#    to decide how to implement the Relay representation to a hardware target.
+#    Relay (or more detailedly, its fusion pass) is in charge of splitting the
+#    neural network into small subgraphs, each of which is a task. Here
+#    lowering means going lowering into TE tasks. The first step is to lower
+#    each task within the Relay model into a tensor expression. The tensor
+#    expressions describe the operations, aka functions, contained within a
+#    neural network. Once transformed into TE, further optimizations for the
+#    specific hardware target can be made. Work is underway to replace TR with
+#    a new representation, Tensor Intermediate Representation (TIR), that
+#    includes TE as a subset of TIR.
+#
+# 4. Search for an optimized schedule using *AutoTVM* or *AutoScheduler*.
+#    Tuning is defined as the process of searching for a schedule (an ordered
+#    notation) for the neural network to be compiled. There are couple of

Review comment:
       What is an ordered notation?

##########
File path: tutorials/get_started/introduction.py
##########
@@ -0,0 +1,137 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Introduction
+============
+**Authors**:
+`Jocelyn Shiue <https://github.com/>`_,
+`Chris Hoge <https://github.com/hogepodge>`_
+
+Apache TVM is an open source machine learning compiler framework for CPUs,
+GPUs, and machine learning accelerators. It aims to enable machine learning
+engineers to optimize and run computations efficiently on any hardware backend.
+The purpose of this tutorial is to take a guided tour through all of the major
+features of TVM by defining and demonstrating key concepts. A new user should
+be able to work through the tutorial from start to finish and be able to
+operate TVM for automatic model optimization, while having a basic
+understanding of the TVM architecture and how it works.
+
+Contents
+--------
+
+#. :doc:`Introduction <introduction>`
+#. :doc:`Installing TVM <install>`
+#. :doc:`Compiling and Optimizing a Model with TVMC <tvmc_command_line_driver>`
+#. :doc:`Compiling and Optimizing a Model with the Python AutoScheduler <auto_tuning_with_python>`
+#. :doc:`Working with Operators Using Tensor Expressions <tensor_expr_get_started>`
+#. :doc:`Optimizing Operators with Templates and AutoTVM <autotvm_matmul>`
+#. :doc:`Optimizing Operators with AutoScheduling <tune_matmul_x86>`
+#. :doc:`Cross Compilation and Remote Procedure Calls (RPC) <cross_compilation_and_rpc>`
+#. :doc:`Compiling Deep Learning Models for GPUs <relay_quick_start>`
+"""
+
+################################################################################
+# An Overview of TVM and Model Optimization
+# =========================================
+#
+# The diagram below illustrates the steps a machine model takes as it is
+# transformed with the TVM optimizing compiler framework.
+#
+# .. image:: /_static/img/tvm.png
+#   :width: 100%
+#   :alt: A High Level View of TVM
+#
+# 1. Import the model from a framework like *Tensorflow*, *Pytorch*, or *Onnx*.
+#    The importer layer is where TVM can ingest models from other frameworks, like
+#    ONNX, Tensorflow, or PyTorch. The level of support that TVM offers for each
+#    frontend varies as we are constantly improving the open source project. If
+#    you're having issues importing your model into TVM, you may want to try
+#    converting it to ONNX.
+#
+# 2. Translate to *Relay,* TVM's high level model language.
+#    A model that has been imported into TVM is represented in Relay. Relay is a
+#    functional language and intermediate representation (IR) for neural networks.
+#    It has support for:
+#
+#    - Traditional data flow-style representations
+#    - Functional-style scoping, let-binding which makes it a fully featured
+#      differentiable language
+#    - Ability to allow the user to mix the two programming styles
+#
+#    Relay (or more detailedly, its fusion pass) is in charge of splitting the
+#    neural network into small subgraphs, each of which is a task.
+#
+# 3. Lower to *TE*, tensor expressions that define the *computational
+#    operations* of the neural network.
+#    Upon completing the import and high level optimizations, the next step it

Review comment:
       ```suggestion
   #    Upon completing the import and high level optimizations, the next step is
   ```

##########
File path: tutorials/get_started/introduction.py
##########
@@ -0,0 +1,137 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Introduction
+============
+**Authors**:
+`Jocelyn Shiue <https://github.com/>`_,
+`Chris Hoge <https://github.com/hogepodge>`_
+
+Apache TVM is an open source machine learning compiler framework for CPUs,
+GPUs, and machine learning accelerators. It aims to enable machine learning
+engineers to optimize and run computations efficiently on any hardware backend.
+The purpose of this tutorial is to take a guided tour through all of the major
+features of TVM by defining and demonstrating key concepts. A new user should
+be able to work through the tutorial from start to finish and be able to
+operate TVM for automatic model optimization, while having a basic
+understanding of the TVM architecture and how it works.
+
+Contents
+--------
+
+#. :doc:`Introduction <introduction>`
+#. :doc:`Installing TVM <install>`
+#. :doc:`Compiling and Optimizing a Model with TVMC <tvmc_command_line_driver>`
+#. :doc:`Compiling and Optimizing a Model with the Python AutoScheduler <auto_tuning_with_python>`
+#. :doc:`Working with Operators Using Tensor Expressions <tensor_expr_get_started>`
+#. :doc:`Optimizing Operators with Templates and AutoTVM <autotvm_matmul>`
+#. :doc:`Optimizing Operators with AutoScheduling <tune_matmul_x86>`
+#. :doc:`Cross Compilation and Remote Procedure Calls (RPC) <cross_compilation_and_rpc>`
+#. :doc:`Compiling Deep Learning Models for GPUs <relay_quick_start>`
+"""
+
+################################################################################
+# An Overview of TVM and Model Optimization
+# =========================================
+#
+# The diagram below illustrates the steps a machine model takes as it is
+# transformed with the TVM optimizing compiler framework.
+#
+# .. image:: /_static/img/tvm.png
+#   :width: 100%
+#   :alt: A High Level View of TVM
+#
+# 1. Import the model from a framework like *Tensorflow*, *Pytorch*, or *Onnx*.
+#    The importer layer is where TVM can ingest models from other frameworks, like
+#    ONNX, Tensorflow, or PyTorch. The level of support that TVM offers for each
+#    frontend varies as we are constantly improving the open source project. If
+#    you're having issues importing your model into TVM, you may want to try
+#    converting it to ONNX.
+#
+# 2. Translate to *Relay,* TVM's high level model language.
+#    A model that has been imported into TVM is represented in Relay. Relay is a
+#    functional language and intermediate representation (IR) for neural networks.
+#    It has support for:
+#
+#    - Traditional data flow-style representations
+#    - Functional-style scoping, let-binding which makes it a fully featured
+#      differentiable language
+#    - Ability to allow the user to mix the two programming styles
+#
+#    Relay (or more detailedly, its fusion pass) is in charge of splitting the
+#    neural network into small subgraphs, each of which is a task.
+#
+# 3. Lower to *TE*, tensor expressions that define the *computational
+#    operations* of the neural network.
+#    Upon completing the import and high level optimizations, the next step it
+#    to decide how to implement the Relay representation to a hardware target.
+#    Relay (or more detailedly, its fusion pass) is in charge of splitting the

Review comment:
       ```suggestion
   #    Relay (or more specifically, its fusion pass) is in charge of splitting the
   ```
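
For readers new to the idea that Relay's fusion pass splits a network into small subgraphs, each of which becomes a task, here is a deliberately simplified pure-Python sketch. It is not TVM's fusion algorithm, which works on a full dataflow graph using operator pattern kinds and dominator analysis. It only handles a linear chain of ops and assumes a hypothetical two-way labeling: "anchor" ops (such as a convolution or dense layer) start a new group, and "elemwise" ops are folded into the current group.

```python
from typing import List, Tuple

# Each op is (name, kind); kind is "anchor" or "elemwise". This labeling is a
# hypothetical simplification, not TVM's actual set of operator patterns.
Op = Tuple[str, str]


def fuse_chain(ops: List[Op]) -> List[List[str]]:
    """Greedily group a linear chain of ops into fused subgraphs (tasks)."""
    groups: List[List[str]] = []
    for name, kind in ops:
        if kind == "elemwise" and groups:
            # Element-wise ops are folded into the preceding group.
            groups[-1].append(name)
        else:
            # Anchor ops (or a leading element-wise op) start a new group.
            groups.append([name])
    return groups


network = [
    ("conv2d", "anchor"),
    ("add", "elemwise"),
    ("relu", "elemwise"),
    ("dense", "anchor"),
    ("relu", "elemwise"),
]
for task_id, group in enumerate(fuse_chain(network)):
    print(task_id, group)
# 0 ['conv2d', 'add', 'relu']
# 1 ['dense', 'relu']
```

Each resulting group would then be lowered to TE as a single task, which is why fusing cheap element-wise ops into their producer avoids writing intermediate tensors back to memory.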

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -1,32 +1,56 @@
 # Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
+# or more contributor license agreements. See the NOTICE file
 # distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
+# regarding copyright ownership. The ASF licenses this file
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
+# with the License. You may obtain a copy of the License at
 #
 #   http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
+# KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
 """
 .. _tutorial-tensor-expr-get-started:
 
-Get Started with Tensor Expression
-==================================
+Working with Operators Using Tensor Expressions
+===============================================
 **Author**: `Tianqi Chen <https://tqchen.github.io>`_
 
-This is an introductory tutorial to the Tensor expression language in TVM.
-TVM uses a domain specific tensor expression for efficient kernel construction.
+In this tutorial we will turn our attention to how TVM works with Tensor
+Expressions (TE) to create a space to search for performant configurations. TE
+describes tensor computations in a pure functional language (that is, each
+expression has no side effects). When viewed in the context of TVM as a whole,
+Relay describes a computation as a set of operators, and each of these
+operators can be represented as a TE expression where each TE expression takes
+input tensors and produces an output tensor. It's important to note that the
+tensor isn't necessarily a fully materialized array; rather, it is a
+representation of a computation. To generate executable code from a TE, you
+will need to use the scheduling features of TVM.
 
-In this tutorial, we will demonstrate the basic workflow to use
-the tensor expression language.
+This is an introductory tutorial to the Tensor expression language in TVM. TVM
+uses a domain specific tensor expression for efficient kernel construction. We
+will demonstrate the basic workflow to use the tensor expression language with
+two examples. The first example introduces TE and scheduling with vector
+addition. The second expands on these concepts with a step-by-step optimization
+of a matrix multiplication with TW. This matrix multiplication example will

Review comment:
       ```suggestion
   of a matrix multiplication with TE. This matrix multiplication example will
   ```
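
The quoted hunk says the vector addition example introduces TE and scheduling, and that a TE only declares a computation while a schedule decides how its loops run. As a rough pure-Python analogy (not TVM code), the sketch below keeps one compute rule, `C[i] = A[i] + B[i]`, and runs it under two loop structures: the default single loop, and the loop after splitting the axis by a factor, which is what `s[C].split(C.op.axis[0], factor=...)` expresses in TE. The function names are hypothetical.

```python
def compute_rule(a, b, i):
    """The 'what': element i of C, analogous to lambda i: A[i] + B[i]."""
    return a[i] + b[i]


def run_default(a, b):
    """Default schedule: one loop over the single axis."""
    n = len(a)
    c = [0] * n
    for i in range(n):
        c[i] = compute_rule(a, b, i)
    return c


def run_split(a, b, factor):
    """Schedule after splitting the axis into an outer and an inner loop.

    n need not be a multiple of factor, so the inner loop is guarded, in the
    same spirit as the bounds check TVM emits for a non-divisible split.
    """
    n = len(a)
    c = [0] * n
    for outer in range((n + factor - 1) // factor):  # ceil(n / factor)
        for inner in range(factor):
            i = outer * factor + inner
            if i < n:
                c[i] = compute_rule(a, b, i)
    return c


a = list(range(10))
b = [10 * x for x in range(10)]
assert run_split(a, b, 4) == run_default(a, b)
```

Changing the schedule changes only the iteration structure; the compute rule, and therefore the result, stays the same.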

##########
File path: tutorials/get_started/introduction.py
##########
@@ -0,0 +1,137 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Introduction
+============
+**Authors**:
+`Jocelyn Shiue <https://github.com/>`_,
+`Chris Hoge <https://github.com/hogepodge>`_
+
+Apache TVM is an open source machine learning compiler framework for CPUs,
+GPUs, and machine learning accelerators. It aims to enable machine learning
+engineers to optimize and run computations efficiently on any hardware backend.
+The purpose of this tutorial is to take a guided tour through all of the major
+features of TVM by defining and demonstrating key concepts. A new user should
+be able to work through the tutorial from start to finish and be able to
+operate TVM for automatic model optimization, while having a basic
+understanding of the TVM architecture and how it works.
+
+Contents
+--------
+
+#. :doc:`Introduction <introduction>`
+#. :doc:`Installing TVM <install>`
+#. :doc:`Compiling and Optimizing a Model with TVMC <tvmc_command_line_driver>`
+#. :doc:`Compiling and Optimizing a Model with the Python AutoScheduler <auto_tuning_with_python>`
+#. :doc:`Working with Operators Using Tensor Expressions <tensor_expr_get_started>`
+#. :doc:`Optimizing Operators with Templates and AutoTVM <autotvm_matmul>`
+#. :doc:`Optimizing Operators with AutoScheduling <tune_matmul_x86>`
+#. :doc:`Cross Compilation and Remote Procedure Calls (RPC) <cross_compilation_and_rpc>`
+#. :doc:`Compiling Deep Learning Models for GPUs <relay_quick_start>`
+"""
+
+################################################################################
+# An Overview of TVM and Model Optimization
+# =========================================
+#
+# The diagram below illustrates the steps a machine model takes as it is
+# transformed with the TVM optimizing compiler framework.
+#
+# .. image:: /_static/img/tvm.png
+#   :width: 100%
+#   :alt: A High Level View of TVM
+#
+# 1. Import the model from a framework like *Tensorflow*, *Pytorch*, or *Onnx*.
+#    The importer layer is where TVM can ingest models from other frameworks, like
+#    ONNX, Tensorflow, or PyTorch. The level of support that TVM offers for each
+#    frontend varies as we are constantly improving the open source project. If
+#    you're having issues importing your model into TVM, you may want to try
+#    converting it to ONNX.
+#
+# 2. Translate to *Relay,* TVM's high level model language.
+#    A model that has been imported into TVM is represented in Relay. Relay is a
+#    functional language and intermediate representation (IR) for neural networks.
+#    It has support for:
+#
+#    - Traditional data flow-style representations
+#    - Functional-style scoping, let-binding which makes it a fully featured
+#      differentiable language
+#    - Ability to allow the user to mix the two programming styles
+#
+#    Relay (or more detailedly, its fusion pass) is in charge of splitting the
+#    neural network into small subgraphs, each of which is a task.
+#
+# 3. Lower to *TE*, tensor expressions that define the *computational
+#    operations* of the neural network.
+#    Upon completing the import and high level optimizations, the next step it
+#    to decide how to implement the Relay representation to a hardware target.
+#    Relay (or more detailedly, its fusion pass) is in charge of splitting the
+#    neural network into small subgraphs, each of which is a task. Here
+#    lowering means going lowering into TE tasks. The first step is to lower
+#    each task within the Relay model into a tensor expression. The tensor
+#    expressions describe the operations, aka functions, contained within a
+#    neural network. Once transformed into TE, further optimizations for the
+#    specific hardware target can be made. Work is underway to replace TR with
+#    a new representation, Tensor Intermediate Representation (TIR), that
+#    includes TE as a subset of TIR.

Review comment:
       I think this will be confusing to new people. They don't have to know about the future of TVM to use it.

##########
File path: tutorials/get_started/introduction.py
##########
@@ -0,0 +1,137 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Introduction
+============
+**Authors**:
+`Jocelyn Shiue <https://github.com/>`_,
+`Chris Hoge <https://github.com/hogepodge>`_
+
+Apache TVM is an open source machine learning compiler framework for CPUs,
+GPUs, and machine learning accelerators. It aims to enable machine learning
+engineers to optimize and run computations efficiently on any hardware backend.
+The purpose of this tutorial is to take a guided tour through all of the major
+features of TVM by defining and demonstrating key concepts. A new user should
+be able to work through the tutorial from start to finish and be able to
+operate TVM for automatic model optimization, while having a basic
+understanding of the TVM architecture and how it works.
+
+Contents
+--------
+
+#. :doc:`Introduction <introduction>`
+#. :doc:`Installing TVM <install>`
+#. :doc:`Compiling and Optimizing a Model with TVMC <tvmc_command_line_driver>`
+#. :doc:`Compiling and Optimizing a Model with the Python AutoScheduler <auto_tuning_with_python>`
+#. :doc:`Working with Operators Using Tensor Expressions <tensor_expr_get_started>`
+#. :doc:`Optimizing Operators with Templates and AutoTVM <autotvm_matmul>`
+#. :doc:`Optimizing Operators with AutoScheduling <tune_matmul_x86>`
+#. :doc:`Cross Compilation and Remote Procedure Calls (RPC) <cross_compilation_and_rpc>`
+#. :doc:`Compiling Deep Learning Models for GPUs <relay_quick_start>`
+"""
+
+################################################################################
+# An Overview of TVM and Model Optimization
+# =========================================
+#
+# The diagram below illustrates the steps a machine model takes as it is
+# transformed with the TVM optimizing compiler framework.
+#
+# .. image:: /_static/img/tvm.png
+#   :width: 100%
+#   :alt: A High Level View of TVM
+#
+# 1. Import the model from a framework like *Tensorflow*, *Pytorch*, or *Onnx*.
+#    The importer layer is where TVM can ingest models from other frameworks, like
+#    ONNX, Tensorflow, or PyTorch. The level of support that TVM offers for each
+#    frontend varies as we are constantly improving the open source project. If
+#    you're having issues importing your model into TVM, you may want to try
+#    converting it to ONNX.
+#
+# 2. Translate to *Relay,* TVM's high level model language.
+#    A model that has been imported into TVM is represented in Relay. Relay is a
+#    functional language and intermediate representation (IR) for neural networks.
+#    It has support for:
+#
+#    - Traditional data flow-style representations
+#    - Functional-style scoping, let-binding which makes it a fully featured
+#      differentiable language
+#    - Ability to allow the user to mix the two programming styles
+#
+#    Relay (or more detailedly, its fusion pass) is in charge of splitting the
+#    neural network into small subgraphs, each of which is a task.
+#
+# 3. Lower to *TE*, tensor expressions that define the *computational
+#    operations* of the neural network.
+#    Upon completing the import and high level optimizations, the next step it
+#    to decide how to implement the Relay representation to a hardware target.
+#    Relay (or more detailedly, its fusion pass) is in charge of splitting the
+#    neural network into small subgraphs, each of which is a task. Here
+#    lowering means going lowering into TE tasks. The first step is to lower
+#    each task within the Relay model into a tensor expression. The tensor
+#    expressions describe the operations, aka functions, contained within a
+#    neural network. Once transformed into TE, further optimizations for the
+#    specific hardware target can be made. Work is underway to replace TR with
+#    a new representation, Tensor Intermediate Representation (TIR), that
+#    includes TE as a subset of TIR.
+#
+# 4. Search for an optimized schedule using *AutoTVM* or *AutoScheduler*.
+#    Tuning is defined as the process of searching for a schedule (an ordered

Review comment:
       ```suggestion
   #    Tuning is the process of searching for a schedule (an ordered
   ```

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -37,124 +61,190 @@
 # Global declarations of environment.
 
 tgt_host = "llvm"
-# Change it to respective GPU if gpu is enabled Ex: cuda, opencl, rocm
-tgt = "cuda"
 
-######################################################################
-# Vector Add Example
-# ------------------
-# In this tutorial, we will use a vector addition example to demonstrate
-# the workflow.
-#
+# You will get better performance if you can identify the CPU you are targeting
+# and specify it. For example, ``tgt = "llvm -mcpu=broadwell"``
+tgt = "llvm"
 
 ######################################################################
-# Describe the Computation
-# ------------------------
-# As a first step, we need to describe our computation.
-# TVM adopts tensor semantics, with each intermediate result
-# represented as a multi-dimensional array. The user needs to describe
-# the computation rule that generates the tensors.
-#
-# We first define a symbolic variable n to represent the shape.
-# We then define two placeholder Tensors, A and B, with given shape (n,)
-#
-# We then describe the result tensor C, with a compute operation.  The
-# compute function takes the shape of the tensor, as well as a lambda
-# function that describes the computation rule for each position of
-# the tensor.
-#
-# No computation happens during this phase, as we are only declaring how
-# the computation should be done.
-#
+# Describing the Vector Computation
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# We describe a vector addition computation. TVM adopts tensor semantics, with
+# each intermediate result represented as a multi-dimensional array. The user
+# needs to describe the computation rule that generates the tensors. We first
+# define a symbolic variable n to represent the shape. We then define two
+# placeholder Tensors, A and B, with given shape (n,). We then describe the
+# result tensor C, with a compute operation. The compute function takes the
+# shape of the tensor, as well as a lambda function that describes the
+# computation rule for each position of the tensor. Note that while n is a
+# variable, it defines a consistent shape between the A, B and C tensors.
+# Remember, no actual computation happens during this phase, as we are only
+# declaring how the computation should be done.
+
 n = te.var("n")
 A = te.placeholder((n,), name="A")
 B = te.placeholder((n,), name="B")
 C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
 print(type(C))
 
 ######################################################################
-# Schedule the Computation
-# ------------------------
-# While the above lines describe the computation rule, we can compute
-# C in many ways since the axis of C can be computed in a data
-# parallel manner.  TVM asks the user to provide a description of the
-# computation called a schedule.
+# Create a Default Schedule for the Computation
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# While the above lines describe the computation rule, we can compute C in many
+# ways since the axis of C can be computed in a data parallel manner. TVM asks
+# the user to provide a description of the computation called a schedule.
 #
-# A schedule is a set of transformation of computation that transforms
-# the loop of computations in the program.
+# A schedule is a set of transformations that rewrite the loops of a
+# computation in the program.

Review comment:
       "A schedule is a set of transformations that transform one set of loops into another."
   
   I would expand on this a lot. Scheduling is confusing.
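One way to make "a schedule turns one loop nest into another" concrete is a short pure-Python sketch (no TVM involved; `default_schedule` and `split_schedule` are illustrative names). It mirrors the `split` example from the earlier version of this tutorial: the single axis of C is split into an outer loop over chunks of 64 and an inner loop within each chunk, and the result is unchanged.

```python
import math
from typing import List


def default_schedule(a: List[float], b: List[float]) -> List[float]:
    # One serial loop in row-major order: for i in [0, n)
    n = len(a)
    c = [0.0] * n
    for i in range(n):
        c[i] = a[i] + b[i]
    return c


def split_schedule(a: List[float], b: List[float], factor: int) -> List[float]:
    # The same loop split into an outer loop over chunks and an inner loop of `factor`
    n = len(a)
    c = [0.0] * n
    for bx in range(math.ceil(n / factor)):
        for tx in range(factor):
            i = bx * factor + tx
            if i < n:  # the last chunk may be partial
                c[i] = a[i] + b[i]
    return c


a = [float(i) for i in range(100)]
b = [0.5 * i for i in range(100)]
# Different loop structure, identical output
assert split_schedule(a, b, 64) == default_schedule(a, b)
```

The split by itself does not make anything faster; it creates a loop structure that later steps (parallelizing the outer loop, vectorizing the inner loop) can act on.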

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -1,32 +1,56 @@
 # Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
+# or more contributor license agreements. See the NOTICE file
 # distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
+# regarding copyright ownership. The ASF licenses this file
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
+# with the License. You may obtain a copy of the License at
 #
 #   http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
+# KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
 """
 .. _tutorial-tensor-expr-get-started:
 
-Get Started with Tensor Expression
-==================================
+Working with Operators Using Tensor Expressions
+===============================================
 **Author**: `Tianqi Chen <https://tqchen.github.io>`_
 
-This is an introductory tutorial to the Tensor expression language in TVM.
-TVM uses a domain specific tensor expression for efficient kernel construction.
+In this tutorial we will turn our attention to how TVM works with Tensor
+Expressions (TE) to create a space to search for performant configurations. TE
+describes tensor computations in a pure functional language (that is, each
+expression has no side effects). When viewed in the context of TVM as a whole,
+Relay describes a computation as a set of operators, and each of these
+operators can be represented as a TE expression where each TE expression takes
+an input tensor and produces an output tensor. It's important to note that the
+tensor isn't necessarily a fully materialized array, rather it is a
+representation of a computation. If you want to produce a computation from a
+TE, you will need to use the scheduling features of TVM to produce a
+computation.
 
-In this tutorial, we will demonstrate the basic workflow to use
-the tensor expression language.
+This is an introductory tutorial to the Tensor expression language in TVM. TVM
+uses a domain specific tensor expression for efficient kernel construction. We
+will demonstrate the basic workflow to use the tensor expression language with
+two examples. The first example introduces TE and scheduling with vector

Review comment:
       ```suggestion
   uses a domain specific tensor expression for efficient kernel construction. We
   will demonstrate the basic workflow with two examples of using the tensor expression
   language. The first example introduces TE and scheduling with vector
   ```

##########
File path: tutorials/get_started/introduction.py
##########
@@ -0,0 +1,137 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Introduction
+============
+**Authors**:
+`Jocelyn Shiue <https://github.com/>`_,
+`Chris Hoge <https://github.com/hogepodge>`_
+
+Apache TVM is an open source machine learning compiler framework for CPUs,
+GPUs, and machine learning accelerators. It aims to enable machine learning
+engineers to optimize and run computations efficiently on any hardware backend.
+The purpose of this tutorial is to take a guided tour through all of the major
+features of TVM by defining and demonstrating key concepts. A new user should
+be able to work through the tutorial from start to finish and be able to
+operate TVM for automatic model optimization, while having a basic
+understanding of the TVM architecture and how it works.
+
+Contents
+--------
+
+#. :doc:`Introduction <introduction>`
+#. :doc:`Installing TVM <install>`
+#. :doc:`Compiling and Optimizing a Model with TVMC <tvmc_command_line_driver>`
+#. :doc:`Compiling and Optimizing a Model with the Python AutoScheduler <auto_tuning_with_python>`
+#. :doc:`Working with Operators Using Tensor Expressions <tensor_expr_get_started>`
+#. :doc:`Optimizing Operators with Templates and AutoTVM <autotvm_matmul>`
+#. :doc:`Optimizing Operators with AutoScheduling <tune_matmul_x86>`
+#. :doc:`Cross Compilation and Remote Procedure Calls (RPC) <cross_compilation_and_rpc>`
+#. :doc:`Compiling Deep Learning Models for GPUs <relay_quick_start>`
+"""
+
+################################################################################
+# An Overview of TVM and Model Optimization
+# =========================================
+#
+# The diagram below illustrates the steps a machine learning model takes as it is
+# transformed with the TVM optimizing compiler framework.
+#
+# .. image:: /_static/img/tvm.png
+#   :width: 100%
+#   :alt: A High Level View of TVM
+#
+# 1. Import the model from a framework like *TensorFlow*, *PyTorch*, or *ONNX*.
+#    The importer layer is where TVM can ingest models from other frameworks, like
+#    ONNX, TensorFlow, or PyTorch. The level of support that TVM offers for each
+#    frontend varies as we are constantly improving the open source project. If
+#    you're having issues importing your model into TVM, you may want to try
+#    converting it to ONNX.
+#
+# 2. Translate to *Relay*, TVM's high-level model language.
+#    A model that has been imported into TVM is represented in Relay. Relay is a
+#    functional language and intermediate representation (IR) for neural networks.
+#    It has support for:
+#
+#    - Traditional data flow-style representations
+#    - Functional-style scoping and let-binding, which make it a fully featured
+#      differentiable language
+#    - The ability to mix the two programming styles
+#
+#    Relay (or more specifically, its fusion pass) is in charge of splitting the
+#    neural network into small subgraphs, each of which is a task.
+#
+# 3. Lower to *TE*, tensor expressions that define the *computational
+#    operations* of the neural network.
+#    Upon completing the import and high-level optimizations, the next step is
+#    to decide how to implement the Relay representation on a hardware target.
+#    Relay (or more specifically, its fusion pass) is in charge of splitting the
+#    neural network into small subgraphs, each of which is a task. Here
+#    lowering means lowering into TE tasks. The first step is to lower
+#    each task within the Relay model into a tensor expression. The tensor
+#    expressions describe the operations, aka functions, contained within a
+#    neural network. Once transformed into TE, further optimizations for the
+#    specific hardware target can be made. Work is underway to replace TE with
+#    a new representation, Tensor Intermediate Representation (TIR), that
+#    includes TE as a subset of TIR.
+#
+# 4. Search for an optimized schedule using *AutoTVM* or *AutoScheduler*.
+#    Tuning is defined as the process of searching for a schedule (an ordered
+#    notation) for the neural network to be compiled. There are a couple of
+#    optimization options available, each requiring varying levels of user
+#    interaction. Both of these methods can draw from the TVM Operator
+#    Inventory (TOPI). TOPI includes pre-defined templates of common machine
+#    learning operations. The optimization options include:
+#
+#    - **AutoTVM**: The user specifies a search template for the schedule of a TE task,
+#      or TE subgraph. AutoTVM directs the search of the parameter space defined by the
+#      template to produce an optimized configuration. AutoTVM requires users to
+#      manually define templates for each operator as part of the TOPI.
+#    - **Ansor/AutoScheduler**: Using the operations in TOPI, Ansor can
+#      automatically search an optimization space with much less intervention
+#      and guidance from the end user. Ansor derives its search space from the
+#      TE compute definitions rather than from hand-written schedule templates.
+#
+# 5. Determine the optimal schedule. After tuning, a schedule is chosen to
+#    optimize on. Whether AutoTVM or AutoScheduler is used, tuning produces
+#    schedule records in JSON format. Afterwards, the best schedule found is chosen to
+#    determine how to optimize each layer of the neural network.
+#
+# 6. Lower to a hardware-specific compiler. TVM tuning operates by computing

Review comment:
       I would move the "TVM tuning operates by..." to the previous step. Then you can just say something about "TE along with optimized schedules (if provided by tuning) are lowered into device-specific code by the relevant backend".
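Step 5 (picking the best schedule per task from the tuning records) boils down to a "minimum cost per task" selection. The sketch below shows only that selection logic. The record layout (`task`, `config`, `costs` keys) is a hypothetical simplification, not the actual AutoTVM or AutoScheduler log schema, and the numbers in `LOG` are invented for illustration.

```python
import json
from statistics import mean
from typing import Dict, Iterable


def best_config_per_task(lines: Iterable[str]) -> Dict[str, dict]:
    """Pick the lowest-mean-cost configuration for every task in a log.

    Each non-empty line is assumed to be a JSON object with the hypothetical
    keys "task", "config" and "costs" (measured run times in seconds).
    This is NOT the real AutoTVM/AutoScheduler log format.
    """
    best: Dict[str, tuple] = {}
    for line in lines:
        if not line.strip():
            continue
        record = json.loads(line)
        cost = mean(record["costs"])
        task = record["task"]
        if task not in best or cost < best[task][0]:
            best[task] = (cost, record["config"])
    return {task: config for task, (_, config) in best.items()}


# Invented example records, for illustration only
LOG = [
    '{"task": "dense", "config": {"tile": 16}, "costs": [0.0042, 0.0040]}',
    '{"task": "dense", "config": {"tile": 64}, "costs": [0.0031, 0.0033]}',
    '{"task": "conv2d", "config": {"tile": 8}, "costs": [0.0120, 0.0118]}',
    '{"task": "conv2d", "config": {"tile": 32}, "costs": [0.0150, 0.0149]}',
]
print(best_config_per_task(LOG))  # {'dense': {'tile': 64}, 'conv2d': {'tile': 8}}
```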

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -37,124 +61,190 @@
 # Global declarations of environment.
 
 tgt_host = "llvm"
-# Change it to respective GPU if gpu is enabled Ex: cuda, opencl, rocm
-tgt = "cuda"
 
-######################################################################
-# Vector Add Example
-# ------------------
-# In this tutorial, we will use a vector addition example to demonstrate
-# the workflow.
-#
+# You will get better performance if you can identify the CPU you are targeting and specify it.
+# For example, ``tgt = "llvm -mcpu=broadwell"``
+tgt = "llvm"
 
 ######################################################################
-# Describe the Computation
-# ------------------------
-# As a first step, we need to describe our computation.
-# TVM adopts tensor semantics, with each intermediate result
-# represented as a multi-dimensional array. The user needs to describe
-# the computation rule that generates the tensors.
-#
-# We first define a symbolic variable n to represent the shape.
-# We then define two placeholder Tensors, A and B, with given shape (n,)
-#
-# We then describe the result tensor C, with a compute operation.  The
-# compute function takes the shape of the tensor, as well as a lambda
-# function that describes the computation rule for each position of
-# the tensor.
-#
-# No computation happens during this phase, as we are only declaring how
-# the computation should be done.
-#
+# Describing the Vector Computation
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# We describe a vector addition computation. TVM adopts tensor semantics, with
+# each intermediate result represented as a multi-dimensional array. The user
+# needs to describe the computation rule that generates the tensors. We first
+# define a symbolic variable n to represent the shape. We then define two
+# placeholder Tensors, A and B, with given shape (n,). We then describe the
+# result tensor C, with a compute operation. The compute function takes the
+# shape of the tensor, as well as a lambda function that describes the
+# computation rule for each position of the tensor. Note that while n is a
+# variable, it defines a consistent shape between the A, B and C tensors.
+# Remember, no actual computation happens during this phase, as we are only
+# declaring how the computation should be done.
+
 n = te.var("n")
 A = te.placeholder((n,), name="A")
 B = te.placeholder((n,), name="B")
 C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
 print(type(C))
 
 ######################################################################
-# Schedule the Computation
-# ------------------------
-# While the above lines describe the computation rule, we can compute
-# C in many ways since the axis of C can be computed in a data
-# parallel manner.  TVM asks the user to provide a description of the
-# computation called a schedule.
+# Create a Default Schedule for the Computation
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# While the above lines describe the computation rule, we can compute C in many
+# ways since the axis of C can be computed in a data parallel manner. TVM asks
+# the user to provide a description of the computation called a schedule.

Review comment:
       I think it would be helpful to spell this out more. Like provide a couple different ways we could do the computation.
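To illustrate "a couple different ways we could do the computation", here is a pure-Python sketch (not TVM code; the function names are hypothetical) of three legal orderings of the same vector add. They are all valid because each `C[i]` depends only on `A[i]` and `B[i]`, which is what makes the axis data parallel. Note that CPython threads do not speed up this loop because of the GIL; the point is the decomposition, which TVM would turn into real parallel machine code.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import List


def serial_add(a: List[float], b: List[float]) -> List[float]:
    # Way 1: one loop, i = 0, 1, ..., n-1 (the default schedule)
    return [a[i] + b[i] for i in range(len(a))]


def reversed_add(a: List[float], b: List[float]) -> List[float]:
    # Way 2: the same loop run backwards; legal because no C[i] depends on another
    n = len(a)
    c = [0.0] * n
    for i in reversed(range(n)):
        c[i] = a[i] + b[i]
    return c


def chunked_parallel_add(a: List[float], b: List[float], num_workers: int = 4) -> List[float]:
    # Way 3: split the axis into chunks and hand each chunk to a worker
    n = len(a)
    if n == 0:
        return []
    c = [0.0] * n
    chunk = -(-n // num_workers)  # ceiling division

    def work(start: int) -> None:
        # Each worker writes a disjoint slice of c, so no locking is needed
        for i in range(start, min(start + chunk, n)):
            c[i] = a[i] + b[i]

    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        list(pool.map(work, range(0, n, chunk)))
    return c
```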

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -37,124 +61,190 @@
 # Global declarations of environment.
 
 tgt_host = "llvm"
-# Change it to respective GPU if gpu is enabled Ex: cuda, opencl, rocm
-tgt = "cuda"
 
-######################################################################
-# Vector Add Example
-# ------------------
-# In this tutorial, we will use a vector addition example to demonstrate
-# the workflow.
-#
+# You will get better performance if you can identify the CPU you are targeting and specify it.
+# For example, ``tgt = "llvm -mcpu=broadwell"``
+tgt = "llvm"
 
 ######################################################################
-# Describe the Computation
-# ------------------------
-# As a first step, we need to describe our computation.
-# TVM adopts tensor semantics, with each intermediate result
-# represented as a multi-dimensional array. The user needs to describe
-# the computation rule that generates the tensors.
-#
-# We first define a symbolic variable n to represent the shape.
-# We then define two placeholder Tensors, A and B, with given shape (n,)
-#
-# We then describe the result tensor C, with a compute operation.  The
-# compute function takes the shape of the tensor, as well as a lambda
-# function that describes the computation rule for each position of
-# the tensor.
-#
-# No computation happens during this phase, as we are only declaring how
-# the computation should be done.
-#
+# Describing the Vector Computation
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# We describe a vector addition computation. TVM adopts tensor semantics, with
+# each intermediate result represented as a multi-dimensional array. The user
+# needs to describe the computation rule that generates the tensors. We first
+# define a symbolic variable n to represent the shape. We then define two
+# placeholder Tensors, A and B, with given shape (n,). We then describe the
+# result tensor C, with a compute operation. The compute function takes the

Review comment:
       A little more detail on what `compute` does would be helpful here.
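As background for what `compute` does: `te.compute` takes an output shape and a function from index variables to a scalar expression, and it only records that rule; nothing is evaluated until a schedule is built and run. The pure-Python sketch below mimics that split between declaration and evaluation. `ComputeDecl`, `compute`, and `realize` are hypothetical names, not TVM APIs, and the sketch handles only 1-D shapes.

```python
from dataclasses import dataclass
from typing import Callable, List, Tuple


@dataclass(frozen=True)
class ComputeDecl:
    """Records *what* to compute: an output shape plus a per-index rule."""
    shape: Tuple[int, ...]
    fcompute: Callable[[int], float]
    name: str


def compute(shape, fcompute, name="compute") -> ComputeDecl:
    # Declaration only: fcompute is stored, not called
    return ComputeDecl(tuple(shape), fcompute, name)


def realize(decl: ComputeDecl) -> List[float]:
    # Evaluate the rule at every index of a 1-D shape, in row-major order
    (n,) = decl.shape
    return [decl.fcompute(i) for i in range(n)]


A = [1.0, 2.0, 3.0]
B = [10.0, 20.0, 30.0]
C = compute((len(A),), lambda i: A[i] + B[i], name="C")
print(C.name, realize(C))  # C [11.0, 22.0, 33.0]
```

In TVM the "realize" step is played by the schedule plus `tvm.build`, which is what allows the same declaration to be evaluated in many different loop orders.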

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -1,32 +1,56 @@
 # Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
+# or more contributor license agreements. See the NOTICE file
 # distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
+# regarding copyright ownership. The ASF licenses this file
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
+# with the License. You may obtain a copy of the License at
 #
 #   http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
+# KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
 """
 .. _tutorial-tensor-expr-get-started:
 
-Get Started with Tensor Expression
-==================================
+Working with Operators Using Tensor Expressions
+===============================================
 **Author**: `Tianqi Chen <https://tqchen.github.io>`_
 
-This is an introductory tutorial to the Tensor expression language in TVM.
-TVM uses a domain specific tensor expression for efficient kernel construction.
+In this tutorial we will turn our attention to how TVM works with Tensor
+Expressions (TE) to create a space to search for performant configurations. TE
+describes tensor computations in a pure functional language (that is, each
+expression has no side effects). When viewed in the context of TVM as a whole,
+Relay describes a computation as a set of operators, and each of these
+operators can be represented as a TE expression where each TE expression takes
+an input tensor and produces an output tensor. It's important to note that the
+tensor isn't necessarily a fully materialized array, rather it is a
+representation of a computation. If you want to produce a computation from a
+TE, you will need to use the scheduling features of TVM to produce a
+computation.

Review comment:
       ```suggestion
   TE, you will need to use the scheduling features of TVM.
   ```

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -37,124 +61,190 @@
 # Global declarations of environment.
 
 tgt_host = "llvm"
-# Change it to respective GPU if gpu is enabled Ex: cuda, opencl, rocm
-tgt = "cuda"
 
-######################################################################
-# Vector Add Example
-# ------------------
-# In this tutorial, we will use a vector addition example to demonstrate
-# the workflow.
-#
+# You will get better performance if you can identify the CPU you are targeting and specify it.
+# For example, ``tgt = "llvm -mcpu=broadwell"``
+tgt = "llvm"
 
 ######################################################################
-# Describe the Computation
-# ------------------------
-# As a first step, we need to describe our computation.
-# TVM adopts tensor semantics, with each intermediate result
-# represented as a multi-dimensional array. The user needs to describe
-# the computation rule that generates the tensors.
-#
-# We first define a symbolic variable n to represent the shape.
-# We then define two placeholder Tensors, A and B, with given shape (n,)
-#
-# We then describe the result tensor C, with a compute operation.  The
-# compute function takes the shape of the tensor, as well as a lambda
-# function that describes the computation rule for each position of
-# the tensor.
-#
-# No computation happens during this phase, as we are only declaring how
-# the computation should be done.
-#
+# Describing the Vector Computation
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# We describe a vector addition computation. TVM adopts tensor semantics, with
+# each intermediate result represented as a multi-dimensional array. The user
+# needs to describe the computation rule that generates the tensors. We first
+# define a symbolic variable n to represent the shape. We then define two
+# placeholder Tensors, A and B, with given shape (n,). We then describe the
+# result tensor C, with a compute operation. The compute function takes the
+# shape of the tensor, as well as a lambda function that describes the
+# computation rule for each position of the tensor. Note that while n is a
+# variable, it defines a consistent shape between the A, B and C tensors.
+# Remember, no actual computation happens during this phase, as we are only
+# declaring how the computation should be done.
+
 n = te.var("n")
 A = te.placeholder((n,), name="A")
 B = te.placeholder((n,), name="B")
 C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
 print(type(C))
 
 ######################################################################
-# Schedule the Computation
-# ------------------------
-# While the above lines describe the computation rule, we can compute
-# C in many ways since the axis of C can be computed in a data
-# parallel manner.  TVM asks the user to provide a description of the
-# computation called a schedule.
+# Create a Default Schedule for the Computation
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# While the above lines describe the computation rule, we can compute C in many
+# ways since the axis of C can be computed in a data parallel manner. TVM asks
+# the user to provide a description of the computation called a schedule.
 #
-# A schedule is a set of transformation of computation that transforms
-# the loop of computations in the program.
+# A schedule is a set of transformations that rewrite the loops of a
+# computation in the program.
 #
-# After we construct the schedule, by default the schedule computes
-# C in a serial manner in a row-major order.
+# By default, TVM will compute a schedule for C in a serial manner using
+# row-major order.
 #
 # .. code-block:: c
 #
 #   for (int i = 0; i < n; ++i) {
 #     C[i] = A[i] + B[i];
 #   }
-#
+
 s = te.create_schedule(C.op)
 
 ######################################################################
-# We used the split construct to split the first axis of C,
-# this will split the original iteration axis into product of
-# two iterations. This is equivalent to the following code.
-#
-# .. code-block:: c
-#
-#   for (int bx = 0; bx < ceil(n / 64); ++bx) {
-#     for (int tx = 0; tx < 64; ++tx) {
-#       int i = bx * 64 + tx;
-#       if (i < n) {
-#         C[i] = A[i] + B[i];
-#       }
-#     }
-#   }
-#
+# Compile and Evaluate the Default Schedule
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# With the schedule created, we can now compile it down to our target language
+# and architecture, in this case LLVM to a CPU. We provide TVM with the basic
+# schedule, a list of the TE expressions that are in the schedule, the target
+# and host, and the name of the function we are producing. The output is a
+# type-erased function that can be called directly from Python.
+#
+# In the following line, we use tvm.build to create a function. The build
+# function takes the schedule, the desired signature of the function (including
+# the inputs and outputs) as well as the target language we want to compile to.
+
+fadd = tvm.build(s, [A, B, C], tgt, target_host=tgt_host, name="myadd")
+
+################################################################################
+# Let's run the function and compare the output to the same computation in
+# NumPy. We begin by creating a context, which represents the device that the
+# compiled function will run on. In this case the context is the CPU that the
+# LLVM target compiles for. We can then initialize the tensors in our context
+# and perform the custom addition operation. To verify that the computation is
+# correct, we can compare the output tensor c to the same computation performed
+# by NumPy.
+
+ctx = tvm.context(tgt, 0)
+
+n = 1024
+a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
+b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
+c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
+fadd(a, b, c)
+tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
+
+################################################################################
+# Updating the Schedule to Use Paralleism
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# Now that we've illustrated the fundamentals of TE, let's go deeper into what
+# schedules do, and how they can be used to optimize tensor expressions for
+# different architectures. A schedule is a series of steps that are applied to
+# a TE to transform it in a number of different ways. When a schedule is
+# applied to a TE, the inputs and outputs remain the same, but when compiled
+# the implementation of the expression can change. This tensor addition, in the

Review comment:
       "the implementation of the expression can change" kinda doesn't say anything. Maybe talk about loop ordering?
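Loop ordering is one concrete way "the implementation can change" while inputs and outputs stay fixed. The pure-Python sketch below (hypothetical `add_2d` helper, no TVM) runs the same 2-D element-wise add with the loop nest in `i, j` order and in interchanged `j, i` order, and records the visit order. Both produce the same matrix; for row-major storage, the `i, j` order walks memory contiguously while `j, i` strides across rows, which is the kind of difference a schedule controls.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def add_2d(a: Matrix, b: Matrix, order: str = "ij") -> Tuple[Matrix, List[Tuple[int, int]]]:
    """Element-wise add of two m x n matrices stored as lists of rows.

    order="ij": i outer, j inner (row-major traversal)
    order="ji": j outer, i inner (interchanged loops)
    Returns the result and the order in which (i, j) elements were visited.
    """
    m, n = len(a), len(a[0])
    c = [[0.0] * n for _ in range(m)]
    if order == "ij":
        pairs = ((i, j) for i in range(m) for j in range(n))
    elif order == "ji":
        pairs = ((i, j) for j in range(n) for i in range(m))
    else:
        raise ValueError(f"unknown loop order: {order!r}")
    visited = []
    for i, j in pairs:
        c[i][j] = a[i][j] + b[i][j]
        visited.append((i, j))
    return c, visited


a = [[1.0, 2.0], [3.0, 4.0]]
b = [[10.0, 20.0], [30.0, 40.0]]
c_ij, order_ij = add_2d(a, b, "ij")
c_ji, order_ji = add_2d(a, b, "ji")
print(order_ij)  # [(0, 0), (0, 1), (1, 0), (1, 1)]
print(order_ji)  # [(0, 0), (1, 0), (0, 1), (1, 1)]
```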

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -37,124 +61,190 @@
 # Global declarations of environment.
 
 tgt_host = "llvm"
-# Change it to respective GPU if gpu is enabled Ex: cuda, opencl, rocm
-tgt = "cuda"
 
-######################################################################
-# Vector Add Example
-# ------------------
-# In this tutorial, we will use a vector addition example to demonstrate
-# the workflow.
-#
+# You will get better performance if you can identify the CPU you are targeting and specify it.

Review comment:
       You may want to say how a user can find this configuration as it is a question that comes up often.
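One way a user could find the `-mcpu` value is to ask LLVM itself: `llc --version` reports the detected host CPU on a `Host CPU:` line, and `llc -mcpu=help` lists the CPU names LLVM accepts. The sketch below assumes `llc` is on `PATH` and falls back to plain `"llvm"` otherwise; the helper names are hypothetical, and `SAMPLE_LLC_VERSION` is an illustrative excerpt of the output format, not output from any particular machine.

```python
import re
import shutil
import subprocess
from typing import Optional

# Illustrative excerpt of `llc --version` output
SAMPLE_LLC_VERSION = """LLVM (http://llvm.org/):
  LLVM version 10.0.0
  Optimized build.
  Default target: x86_64-pc-linux-gnu
  Host CPU: skylake
"""


def parse_host_cpu(llc_version_output: str) -> Optional[str]:
    # Pull the CPU name from the "Host CPU: <name>" line, if present
    match = re.search(r"Host CPU:\s*(\S+)", llc_version_output)
    return match.group(1) if match else None


def detect_llvm_target() -> str:
    # Build a TVM target string such as "llvm -mcpu=skylake"
    llc = shutil.which("llc")
    if llc is None:
        return "llvm"  # llc not installed: use the generic target
    out = subprocess.run([llc, "--version"], capture_output=True, text=True).stdout
    cpu = parse_host_cpu(out)
    if cpu and cpu != "(unknown)":
        return f"llvm -mcpu={cpu}"
    return "llvm"


print(parse_host_cpu(SAMPLE_LLC_VERSION))  # skylake
print(detect_llvm_target())
```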

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -37,124 +61,190 @@
 # Global declarations of environment.
 
 tgt_host = "llvm"
-# Change it to respective GPU if gpu is enabled Ex: cuda, opencl, rocm
-tgt = "cuda"
 
-######################################################################
-# Vector Add Example
-# ------------------
-# In this tutorial, we will use a vector addition example to demonstrate
-# the workflow.
-#
+# You will get better performance if you can identify the CPU you are targeting and specify it.
+# For example, ``tgt = "llvm -mcpu=broadwell"``
+tgt = "llvm"
 
 ######################################################################
-# Describe the Computation
-# ------------------------
-# As a first step, we need to describe our computation.
-# TVM adopts tensor semantics, with each intermediate result
-# represented as a multi-dimensional array. The user needs to describe
-# the computation rule that generates the tensors.
-#
-# We first define a symbolic variable n to represent the shape.
-# We then define two placeholder Tensors, A and B, with given shape (n,)
-#
-# We then describe the result tensor C, with a compute operation.  The
-# compute function takes the shape of the tensor, as well as a lambda
-# function that describes the computation rule for each position of
-# the tensor.
-#
-# No computation happens during this phase, as we are only declaring how
-# the computation should be done.
-#
+# Describing the Vector Computation
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# We describe a vector addition computation. TVM adopts tensor semantics, with
+# each intermediate result represented as a multi-dimensional array. The user
+# needs to describe the computation rule that generates the tensors. We first
+# define a symbolic variable n to represent the shape. We then define two
+# placeholder Tensors, A and B, with given shape (n,). We then describe the
+# result tensor C, with a compute operation. The compute function takes the
+# shape of the tensor, as well as a lambda function that describes the
+# computation rule for each position of the tensor. Note that while n is a
+# variable, it defines a consistent shape between the A, B and C tensors.
+# Remember, no actual computation happens during this phase, as we are only
+# declaring how the computation should be done.
+
 n = te.var("n")
 A = te.placeholder((n,), name="A")
 B = te.placeholder((n,), name="B")
 C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
 print(type(C))
 
 ######################################################################
-# Schedule the Computation
-# ------------------------
-# While the above lines describe the computation rule, we can compute
-# C in many ways since the axis of C can be computed in a data
-# parallel manner.  TVM asks the user to provide a description of the
-# computation called a schedule.
+# Create a Default Schedule for the Computation
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# While the above lines describe the computation rule, we can compute C in many
+# ways since the axis of C can be computed in a data parallel manner. TVM asks
+# the user to provide a description of the computation called a schedule.
 #
-# A schedule is a set of transformation of computation that transforms
-# the loop of computations in the program.
+# A schedule is a set of transformations that rewrite the loops of a
+# computation in the program.
 #
-# After we construct the schedule, by default the schedule computes
-# C in a serial manner in a row-major order.
+# By default, TVM will compute a schedule for C in a serial manner using
+# row-major order.
 #
 # .. code-block:: c
 #
 #   for (int i = 0; i < n; ++i) {
 #     C[i] = A[i] + B[i];
 #   }
-#
+
 s = te.create_schedule(C.op)
 
 ######################################################################
-# We used the split construct to split the first axis of C,
-# this will split the original iteration axis into product of
-# two iterations. This is equivalent to the following code.
-#
-# .. code-block:: c
-#
-#   for (int bx = 0; bx < ceil(n / 64); ++bx) {
-#     for (int tx = 0; tx < 64; ++tx) {
-#       int i = bx * 64 + tx;
-#       if (i < n) {
-#         C[i] = A[i] + B[i];
-#       }
-#     }
-#   }
-#
+# Compile and Evaluate the Default Schedule
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# With the schedule created, we can now compile it down to our target language

Review comment:
       "With a TE expression and a schedule we can produce runnable code for our target language and architecture"

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -37,124 +61,190 @@
 # Global declarations of environment.
 
 tgt_host = "llvm"
-# Change it to respective GPU if gpu is enabled Ex: cuda, opencl, rocm
-tgt = "cuda"
 
-######################################################################
-# Vector Add Example
-# ------------------
-# In this tutorial, we will use a vector addition example to demonstrate
-# the workflow.
-#
+# You will get better performance if you can identify the CPU you are targeting and specify it.
+# For example, ``tgt = "llvm -mcpu=broadwell"``
+tgt = "llvm"
 
 ######################################################################
-# Describe the Computation
-# ------------------------
-# As a first step, we need to describe our computation.
-# TVM adopts tensor semantics, with each intermediate result
-# represented as a multi-dimensional array. The user needs to describe
-# the computation rule that generates the tensors.
-#
-# We first define a symbolic variable n to represent the shape.
-# We then define two placeholder Tensors, A and B, with given shape (n,)
-#
-# We then describe the result tensor C, with a compute operation.  The
-# compute function takes the shape of the tensor, as well as a lambda
-# function that describes the computation rule for each position of
-# the tensor.
-#
-# No computation happens during this phase, as we are only declaring how
-# the computation should be done.
-#
+# Describing the Vector Computation
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# We describe a vector addition computation. TVM adopts tensor semantics, with
+# each intermediate result represented as a multi-dimensional array. The user
+# needs to describe the computation rule that generates the tensors. We first
+# define a symbolic variable n to represent the shape. We then define two
+# placeholder tensors, A and B, with shape (n,). Finally, we describe the
+# result tensor C with a compute operation. The compute function takes the
+# shape of the tensor, as well as a lambda function that describes the
+# computation rule for each position of the tensor. Note that while n is a
+# variable, it defines a consistent shape between the A, B and C tensors.
+# Remember, no actual computation happens during this phase, as we are only
+# declaring how the computation should be done.
+
 n = te.var("n")
 A = te.placeholder((n,), name="A")
 B = te.placeholder((n,), name="B")
 C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
 print(type(C))
 
 ######################################################################
-# Schedule the Computation
-# ------------------------
-# While the above lines describe the computation rule, we can compute
-# C in many ways since the axis of C can be computed in a data
-# parallel manner.  TVM asks the user to provide a description of the
-# computation called a schedule.
+# Create a Default Schedule for the Computation
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# While the above lines describe the computation rule, we can compute C in many
+# ways since the axis of C can be computed in a data parallel manner. TVM asks
+# the user to provide a description of the computation called a schedule.
 #
-# A schedule is a set of transformation of computation that transforms
-# the loop of computations in the program.
+# A schedule is a set of transformations of a computation that transform the
+# loops in the program.
 #
-# After we construct the schedule, by default the schedule computes
-# C in a serial manner in a row-major order.
+# By default, TVM will compute a schedule for C in a serial manner using
+# row-major order.
 #
 # .. code-block:: c
 #
 #   for (int i = 0; i < n; ++i) {
 #     C[i] = A[i] + B[i];
 #   }
-#
+
 s = te.create_schedule(C.op)
 
 ######################################################################
-# We used the split construct to split the first axis of C,
-# this will split the original iteration axis into product of
-# two iterations. This is equivalent to the following code.
-#
-# .. code-block:: c
-#
-#   for (int bx = 0; bx < ceil(n / 64); ++bx) {
-#     for (int tx = 0; tx < 64; ++tx) {
-#       int i = bx * 64 + tx;
-#       if (i < n) {
-#         C[i] = A[i] + B[i];
-#       }
-#     }
-#   }
-#
+# Compile and Evaluate the Default Schedule
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# With the schedule created, we can now compile it down to our target language
+# and architecture, in this case LLVM to a CPU. We provide TVM with the basic
+# schedule, a list of the TE expressions that are in the schedule, the target

Review comment:
       ```suggestion
   # and architecture, in this case LLVM to a CPU. We provide TVM with the
   # schedule, a list of the TE expressions that are in the schedule, the target
   ```
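
The removed lines in this hunk showed the loop nest that `split` produces with a factor of 64. The minimal plain-Python sketch below assumes nothing about TVM's generated code beyond that C listing, and `serial_indices` and `split_indices` are hypothetical helpers. It checks that the split nest visits exactly the same indices as the default serial loop, including when n is not a multiple of the factor.

```python
def serial_indices(n):
    """Default schedule: one serial, row-major loop over i."""
    return list(range(n))


def split_indices(n, factor):
    """Split schedule: outer loop bx over ceil(n / factor) chunks, inner loop
    tx over `factor` elements, with a guard because n may not divide evenly."""
    visited = []
    num_outer = (n + factor - 1) // factor  # integer ceil(n / factor)
    for bx in range(num_outer):
        for tx in range(factor):
            i = bx * factor + tx
            if i < n:  # skips the overhanging tail of the last chunk
                visited.append(i)
    return visited


n = 130  # deliberately not a multiple of the factor 64
assert split_indices(n, 64) == serial_indices(n)
```

With n = 130 the outer loop runs three times, and the guard discards 62 of the 64 (bx, tx) pairs in the last chunk.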

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -37,124 +61,190 @@
 # Global declarations of environment.
 
 tgt_host = "llvm"
-# Change it to respective GPU if gpu is enabled Ex: cuda, opencl, rocm
-tgt = "cuda"
 
-######################################################################
-# Vector Add Example
-# ------------------
-# In this tutorial, we will use a vector addition example to demonstrate
-# the workflow.
-#
+# You will get better performance if you can identify the CPU you are targeting and specify it.
+# For example, ``tgt = "llvm -mcpu=broadwell"``
+tgt = "llvm"
 
 ######################################################################
-# Describe the Computation
-# ------------------------
-# As a first step, we need to describe our computation.
-# TVM adopts tensor semantics, with each intermediate result
-# represented as a multi-dimensional array. The user needs to describe
-# the computation rule that generates the tensors.
-#
-# We first define a symbolic variable n to represent the shape.
-# We then define two placeholder Tensors, A and B, with given shape (n,)
-#
-# We then describe the result tensor C, with a compute operation.  The
-# compute function takes the shape of the tensor, as well as a lambda
-# function that describes the computation rule for each position of
-# the tensor.
-#
-# No computation happens during this phase, as we are only declaring how
-# the computation should be done.
-#
+# Describing the Vector Computation
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# We describe a vector addition computation. TVM adopts tensor semantics, with
+# each intermediate result represented as a multi-dimensional array. The user
+# needs to describe the computation rule that generates the tensors. We first
+# define a symbolic variable n to represent the shape. We then define two
+# placeholder tensors, A and B, with shape (n,). Finally, we describe the
+# result tensor C with a compute operation. The compute function takes the
+# shape of the tensor, as well as a lambda function that describes the
+# computation rule for each position of the tensor. Note that while n is a
+# variable, it defines a consistent shape between the A, B and C tensors.
+# Remember, no actual computation happens during this phase, as we are only
+# declaring how the computation should be done.
+
 n = te.var("n")
 A = te.placeholder((n,), name="A")
 B = te.placeholder((n,), name="B")
 C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
 print(type(C))
 
 ######################################################################
-# Schedule the Computation
-# ------------------------
-# While the above lines describe the computation rule, we can compute
-# C in many ways since the axis of C can be computed in a data
-# parallel manner.  TVM asks the user to provide a description of the
-# computation called a schedule.
+# Create a Default Schedule for the Computation
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# While the above lines describe the computation rule, we can compute C in many
+# ways since the axis of C can be computed in a data parallel manner. TVM asks
+# the user to provide a description of the computation called a schedule.
 #
-# A schedule is a set of transformation of computation that transforms
-# the loop of computations in the program.
+# A schedule is a set of transformations of a computation that transform the
+# loops in the program.
 #
-# After we construct the schedule, by default the schedule computes
-# C in a serial manner in a row-major order.
+# By default, TVM will compute a schedule for C in a serial manner using
+# row-major order.
 #
 # .. code-block:: c
 #
 #   for (int i = 0; i < n; ++i) {
 #     C[i] = A[i] + B[i];
 #   }
-#
+
 s = te.create_schedule(C.op)
 
 ######################################################################
-# We used the split construct to split the first axis of C,
-# this will split the original iteration axis into product of
-# two iterations. This is equivalent to the following code.
-#
-# .. code-block:: c
-#
-#   for (int bx = 0; bx < ceil(n / 64); ++bx) {
-#     for (int tx = 0; tx < 64; ++tx) {
-#       int i = bx * 64 + tx;
-#       if (i < n) {
-#         C[i] = A[i] + B[i];
-#       }
-#     }
-#   }
-#
+# Compile and Evaluate the Default Schedule
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# With the schedule created, we can now compile it down to our target language
+# and architecture, in this case LLVM to a CPU. We provide TVM with the basic
+# schedule, a list of the TE expressions that are in the schedule, the target
+# and host, and the name of the function we are producing. The output is a
+# type-erased function that can be called directly from Python.
+#
+# In the following line, we use tvm.build to create a function. The build
+# function takes the schedule, the desired signature of the function (including
+# the inputs and outputs) as well as the target language we want to compile to.
+
+fadd = tvm.build(s, [A, B, C], tgt, target_host=tgt_host, name="myadd")
+
+################################################################################
+# Let's run the function and compare the output to the same computation in
+# NumPy. We begin by creating a context, which identifies the device on which
+# the tensors are allocated and the compiled function runs. In this case, that
+# is the CPU targeted by LLVM. We can then initialize the tensors in our
+# context and perform the custom addition operation. To verify that the
+# computation is correct, we compare the output tensor c to the same
+# computation performed by NumPy.
+
+ctx = tvm.context(tgt, 0)
+
+n = 1024
+a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
+b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
+c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
+fadd(a, b, c)
+tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
+
+################################################################################
+# Updating the Schedule to Use Paralleism
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# Now that we've illustrated the fundamentals of TE, let's go deeper into what
+# schedules do, and how they can be used to optimize tensor expressions for
+# different architectures. A schedule is a series of steps that are applied to
+# a TE to transform it in a number of different ways. When a schedule is
+# applied to a TE, the inputs and outputs remain the same, but when compiled
+# the implementation of the expression can change. This tensor addition, in the
+# default schedule, is run serially but is easy to parallelize across all of
+# the processor threads. We can apply the parallel schedule operation to our
+# computation.
+
+s[C].parallel(C.op.axis[0])
+
+print(tvm.lower(s, [A, B, C], simple_mode=True))
+
+################################################################################
+# TVM can now run the iterations of this loop on independent threads. Let's
+# compile and run this new schedule with the parallel operation applied:
+
+fadd_parallel = tvm.build(s, [A, B, C], tgt, target_host=tgt_host, name="myadd_parallel")
+fadd_parallel(a, b, c)
+
+tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
+
+################################################################################
+# Updating the Schedule to Use Vectorization
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# Modern CPUs can also perform SIMD operations on floating-point values, and
+# we can apply further scheduling primitives to our computation to take
+# advantage of this. This takes multiple steps. First, we split the loop into
+# inner and outer loops using the split scheduling primitive. The inner loop
+# can then be mapped to SIMD instructions with the vectorize scheduling
+# primitive, and the outer loop can be parallelized with the parallel
+# scheduling primitive. Because the inner loop is the one that is vectorized,
+# choose the split factor to match your CPU's SIMD width (for example, 8
+# float32 lanes with AVX2).
+
+# Recreate the schedule, since we modified it with the parallel operation in the previous example
+n = te.var("n")
+A = te.placeholder((n,), name="A")
+B = te.placeholder((n,), name="B")
+C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
+
+s = te.create_schedule(C.op)
+
+factor = 4
+
+outer, inner = s[C].split(C.op.axis[0], factor=factor)
+s[C].parallel(outer)
+s[C].vectorize(inner)
+
+print(tvm.lower(s, [A, B, C], simple_mode=True))
+
+################################################################################
+# We've defined, scheduled, and compiled a vector addition operator, which we
+# were then able to execute on the TVM runtime. We can save the operator as a
+# library, which we can then load later using the TVM runtime.
+
+################################################################################
+# Targeting Vector Addition for GPUs (Optional)
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# TVM is capable of targeting multiple architectures. In the next example, we
+# will target compilation of the vector addition to GPUs.
+
+# Change this target respective GPU if gpu is enabled Ex: cuda, opencl, rocm

Review comment:
       ```suggestion
   # Change this target to the correct backend for your GPU. For example: cuda (NVIDIA GPUs), rocm (Radeon GPUs), OpenGL (???).
   ```
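
The parallel and vectorization steps in this hunk can be mimicked in plain Python to show how the work is divided. Those steps are split by `factor = 4`, then `parallel(outer)`, then `vectorize(inner)`. This is only a structural sketch, not what TVM emits. A thread pool stands in for the parallel outer loop, a slice operation stands in for the SIMD inner loop, and `vector_add_split_parallel` is a hypothetical name. Because of CPython's GIL, the sketch shows the shape of the schedule, not a speedup.

```python
from concurrent.futures import ThreadPoolExecutor


def vector_add_split_parallel(a, b, factor=4, workers=4):
    """Compute a + b with the loop split into chunks of `factor` elements."""
    n = len(a)
    c = [0.0] * n
    num_outer = (n + factor - 1) // factor  # ceil(n / factor) outer iterations

    def outer_body(outer):
        start = outer * factor
        stop = min(start + factor, n)  # ragged tail when n % factor != 0
        # Stand-in for the vectorized inner loop: one slice update per chunk.
        c[start:stop] = [x + y for x, y in zip(a[start:stop], b[start:stop])]

    with ThreadPoolExecutor(max_workers=workers) as pool:
        # Each outer iteration writes a disjoint slice of c.
        list(pool.map(outer_body, range(num_outer)))
    return c


a = [float(i) for i in range(10)]
b = [2.0 * i for i in range(10)]
print(vector_add_split_parallel(a, b))
```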

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -223,22 +295,21 @@
 cc.create_shared(temp.relpath("myadd.so"), [temp.relpath("myadd.o")])
 print(temp.listdir())
 
-######################################################################
+################################################################################
 # .. note:: Module Storage Format
 #
-#   The CPU (host) module is directly saved as a shared library (.so).
-#   There can be multiple customized formats of the device code.
-#   In our example, the device code is stored in ptx, as well as a meta
-#   data json file. They can be loaded and linked separately via import.
-#
+#   The CPU (host) module is directly saved as a shared library (.so). There
+#   can be multiple customized formats of the device code. In our example, the
+#   device code is stored in PTX, along with a metadata JSON file. They can be
+#   loaded and linked separately via import.
 
-######################################################################
+################################################################################
 # Load Compiled Module
-# --------------------
-# We can load the compiled module from the file system and run the code.
-# The following code loads the host and device module separately and
-# re-links them together. We can verify that the newly loaded function works.
-#
+# ~~~~~~~~~~~~~~~~~~~~
+# We can load the compiled module from the file system and run the code. The
+# following code loads the host and device module separately and re-links them

Review comment:
       ```suggestion
   # following code loads the host and device module separately and links them
   ```

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -302,18 +371,437 @@
     fadd_cl(a, b, c)
     tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 
-######################################################################
-# Summary
-# -------
-# This tutorial provides a walk through of TVM workflow using
-# a vector add example. The general workflow is
+################################################################################
+# .. note:: Code Specialization
+#
+#   As you may have noticed, the declarations of A, B and C all take the same
+#   shape argument, n. TVM will take advantage of this to pass only a single
+#   shape argument to the kernel, as you will find in the printed device code.
+#   This is one form of specialization.
+#
+#   On the host side, TVM will automatically generate code that checks the
+#   constraints on the parameters. So if you pass arrays with different
+#   shapes into fadd, an error will be raised.
+#
+#   We can do more specializations. For example, we can write :code:`n =
+#   tvm.runtime.convert(1024)` instead of :code:`n = te.var("n")`, in the
+#   computation declaration. The generated function will only take vectors with
+#   length 1024.
+
+################################################################################
+# .. note:: TE Scheduling Primitives
+#
+#   TVM includes a number of different scheduling primitives:
+#
+#   - split: splits a specified axis into two axes by the defined factor.
+#   - tile: splits a computation across two axes by the defined factors.
+#   - fuse: fuses two consecutive axes of one computation.
+#   - reorder: reorders the axes of a computation into a defined order.
+#   - bind: can bind a computation to a specific thread, useful in GPU programming.
+#   - compute_at: by default, TVM will compute tensors at the root by default. comput_at specifies

Review comment:
       ```suggestion
   #   - compute_at: by default, TVM will compute tensors at the root by default. compute_at specifies
   ```
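
The Code Specialization note says two things. First, the host side checks that the arrays passed to `fadd` agree with the single shared shape argument. Second, fixing n (for example with `tvm.runtime.convert(1024)`) restricts the function to one length. The plain-Python sketch below approximates that behavior. It raises a `ValueError` in place of whatever error TVM actually raises, and `make_checked_add` and `raises_value_error` are hypothetical helpers.

```python
def make_checked_add(fixed_n=None):
    """Return add(a, b, c), which writes a + b into c after checking shapes.

    fixed_n=None models the generic build over te.var("n"); an int models a
    build specialized to that one length."""

    def add(a, b, c):
        n = len(a)  # the single shape argument shared by A, B and C
        if fixed_n is not None and n != fixed_n:
            raise ValueError(f"expected length {fixed_n}, got {n}")
        if len(b) != n or len(c) != n:
            raise ValueError(f"shape mismatch: {len(a)}, {len(b)}, {len(c)}")
        for i in range(n):
            c[i] = a[i] + b[i]

    return add


def raises_value_error(fn, *args):
    """Report whether calling fn(*args) is rejected by the shape checks."""
    try:
        fn(*args)
    except ValueError:
        return True
    return False


generic_add = make_checked_add()
specialized_add = make_checked_add(fixed_n=1024)

out = [0.0] * 3
generic_add([1.0, 2.0, 3.0], [4.0, 5.0, 6.0], out)  # any length is accepted
```

Here `generic_add` accepts any consistent length but rejects mismatched arrays. `specialized_add` also rejects consistent arrays whose length is not 1024.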

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -302,18 +371,437 @@
     fadd_cl(a, b, c)
     tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 
-######################################################################
-# Summary
-# -------
-# This tutorial provides a walk through of TVM workflow using
-# a vector add example. The general workflow is
+################################################################################
+# .. note:: Code Specialization
+#
+#   As you may have noticed, the declarations of A, B and C all take the same
+#   shape argument, n. TVM will take advantage of this to pass only a single
+#   shape argument to the kernel, as you will find in the printed device code.
+#   This is one form of specialization.
+#
+#   On the host side, TVM will automatically generate code that checks the
+#   constraints on the parameters. So if you pass arrays with different
+#   shapes into fadd, an error will be raised.
+#
+#   We can do more specializations. For example, we can write :code:`n =
+#   tvm.runtime.convert(1024)` instead of :code:`n = te.var("n")`, in the
+#   computation declaration. The generated function will only take vectors with
+#   length 1024.
+
+################################################################################
+# .. note:: TE Scheduling Primitives
+#
+#   TVM includes a number of different scheduling primitives:
+#
+#   - split: splits a specified axis into two axes by the defined factor.
+#   - tile: splits a computation across two axes by the defined factors.
+#   - fuse: fuses two consecutive axes of one computation.
+#   - reorder: reorders the axes of a computation into a defined order.
+#   - bind: can bind a computation to a specific thread, useful in GPU programming.
+#   - compute_at: by default, TVM computes tensors at the root. compute_at specifies
+#     that one tensor should be computed at a given axis of another operator's computation.
+#   - compute_inline: when marked inline, a computation will be expanded then inserted into the
+#     address where the tensor is required.
+#   - compute_root: moves a computation to the root stage.
+#
+#   A complete description of these primitives can be found in the
+#   `Schedule Primitives <https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html>`_ docs page.
+
+################################################################################
+# Example 2: Manually Optimizing Matrix Multiplication with TE
+# ------------------------------------------------------------
+#
+# Now we will consider a second, more advanced example, demonstrating how with
+# just 18 line of python code from TVM we can demonstrate up to 18x speedup on
+# a common matrix multiplication operation.

Review comment:
       ```suggestion
   # just 18 lines of python code TVM speeds up a common matrix multiplication operation by 18x.
   ```
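
The matrix multiplication example introduced here relies on primitives such as `tile` and `reorder` from the list above. The plain-Python sketch below is a rough illustration only, not TVM code. The block size `bn` and the bi, bj, k, i, j loop order are assumptions chosen for the sketch. Tiling splits the i and j loops into blocks and reorders them so that each block of C is finished before the next one starts, without changing the result.

```python
def matmul_naive(A, B):
    """Reference i-j-k loop nest: C[i][j] = sum_k A[i][k] * B[k][j]."""
    M, K, N = len(A), len(B), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for i in range(M):
        for j in range(N):
            for k in range(K):
                C[i][j] += A[i][k] * B[k][j]
    return C


def matmul_tiled(A, B, bn=2):
    """Same product with i and j tiled by bn and the k loop moved inside."""
    M, K, N = len(A), len(B), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for bi in range(0, M, bn):  # outer tile over rows
        for bj in range(0, N, bn):  # outer tile over columns
            for k in range(K):  # reduction axis, still in increasing order
                for i in range(bi, min(bi + bn, M)):  # rows inside the tile
                    for j in range(bj, min(bj + bn, N)):  # cols inside the tile
                        C[i][j] += A[i][k] * B[k][j]
    return C


X = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
Y = [[9.0, 8.0, 7.0], [6.0, 5.0, 4.0], [3.0, 2.0, 1.0]]
# A 3x3 matrix with bn=2 exercises the ragged edge tiles as well.
assert matmul_tiled(X, Y) == matmul_naive(X, Y)
```

In compiled code, the benefit of tiling comes from cache reuse within each block. In pure Python, the tiled version is not expected to be faster.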

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -37,124 +61,190 @@
 # Global declarations of environment.
 
 tgt_host = "llvm"
-# Change it to respective GPU if gpu is enabled Ex: cuda, opencl, rocm
-tgt = "cuda"
 
-######################################################################
-# Vector Add Example
-# ------------------
-# In this tutorial, we will use a vector addition example to demonstrate
-# the workflow.
-#
+# You will get better performance if you can identify the CPU you are targeting and specify it.
+# For example, ``tgt = "llvm -mcpu=broadwell"``
+tgt = "llvm"
 
 ######################################################################
-# Describe the Computation
-# ------------------------
-# As a first step, we need to describe our computation.
-# TVM adopts tensor semantics, with each intermediate result
-# represented as a multi-dimensional array. The user needs to describe
-# the computation rule that generates the tensors.
-#
-# We first define a symbolic variable n to represent the shape.
-# We then define two placeholder Tensors, A and B, with given shape (n,)
-#
-# We then describe the result tensor C, with a compute operation.  The
-# compute function takes the shape of the tensor, as well as a lambda
-# function that describes the computation rule for each position of
-# the tensor.
-#
-# No computation happens during this phase, as we are only declaring how
-# the computation should be done.
-#
+# Describing the Vector Computation
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# We describe a vector addition computation. TVM adopts tensor semantics, with
+# each intermediate result represented as a multi-dimensional array. The user
+# needs to describe the computation rule that generates the tensors. We first
+# define a symbolic variable n to represent the shape. We then define two
+# placeholder tensors, A and B, with shape (n,). Finally, we describe the
+# result tensor C with a compute operation. The compute function takes the
+# shape of the tensor, as well as a lambda function that describes the
+# computation rule for each position of the tensor. Note that while n is a
+# variable, it defines a consistent shape between the A, B and C tensors.
+# Remember, no actual computation happens during this phase, as we are only
+# declaring how the computation should be done.
+
 n = te.var("n")
 A = te.placeholder((n,), name="A")
 B = te.placeholder((n,), name="B")
 C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
 print(type(C))
 
 ######################################################################
-# Schedule the Computation
-# ------------------------
-# While the above lines describe the computation rule, we can compute
-# C in many ways since the axis of C can be computed in a data
-# parallel manner.  TVM asks the user to provide a description of the
-# computation called a schedule.
+# Create a Default Schedule for the Computation
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# While the above lines describe the computation rule, we can compute C in many
+# ways since the axis of C can be computed in a data parallel manner. TVM asks
+# the user to provide a description of the computation called a schedule.
 #
-# A schedule is a set of transformation of computation that transforms
-# the loop of computations in the program.
+# A schedule is a set of transformations of a computation that transform the
+# loops in the program.
 #
-# After we construct the schedule, by default the schedule computes
-# C in a serial manner in a row-major order.
+# By default, TVM will compute a schedule for C in a serial manner using
+# row-major order.
 #
 # .. code-block:: c
 #
 #   for (int i = 0; i < n; ++i) {
 #     C[i] = A[i] + B[i];
 #   }
-#
+
 s = te.create_schedule(C.op)
 
 ######################################################################
-# We used the split construct to split the first axis of C,
-# this will split the original iteration axis into product of
-# two iterations. This is equivalent to the following code.
-#
-# .. code-block:: c
-#
-#   for (int bx = 0; bx < ceil(n / 64); ++bx) {
-#     for (int tx = 0; tx < 64; ++tx) {
-#       int i = bx * 64 + tx;
-#       if (i < n) {
-#         C[i] = A[i] + B[i];
-#       }
-#     }
-#   }
-#
+# Compile and Evaluate the Default Schedule
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# With the schedule created, we can now compile it down to our target language
+# and architecture, in this case LLVM to a CPU. We provide TVM with the basic
+# schedule, a list of the TE expressions that are in the schedule, the target
+# and host, and the name of the function we are producing. The output is a
+# type-erased function that can be called directly from Python.
+#
+# In the following line, we use tvm.build to create a function. The build
+# function takes the schedule, the desired signature of the function (including
+# the inputs and outputs) as well as the target language we want to compile to.
+
+fadd = tvm.build(s, [A, B, C], tgt, target_host=tgt_host, name="myadd")
+
+################################################################################
+# Let's run the function and compare the output to the same computation in
+# NumPy. We begin by creating a context, which identifies the device on which
+# the tensors are allocated and the compiled function runs. In this case, that
+# is the CPU targeted by LLVM. We can then initialize the tensors in our
+# context and perform the custom addition operation. To verify that the
+# computation is correct, we compare the output tensor c to the same
+# computation performed by NumPy.
+
+ctx = tvm.context(tgt, 0)
+
+n = 1024
+a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
+b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
+c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
+fadd(a, b, c)
+tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
+
+################################################################################
+# Updating the Schedule to Use Paralleism
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# Now that we've illustrated the fundamentals of TE, let's go deeper into what
+# schedules do, and how they can be used to optimize tensor expressions for
+# different architectures. A schedule is a series of steps that are applied to
+# a TE to transform it in a number of different ways. When a schedule is
+# applied to a TE, the inputs and outputs remain the same, but when compiled
+# the implementation of the expression can change. This tensor addition, in the
+# default schedule, is run serially but is easy to parallelize across all of
+# the processor threads. We can apply the parallel schedule operation to our
+# computation.
+
+s[C].parallel(C.op.axis[0])
+
+print(tvm.lower(s, [A, B, C], simple_mode=True))
+
+################################################################################
+# TVM can now run the iterations of this loop on independent threads. Let's
+# compile and run this new schedule with the parallel operation applied:
+
+fadd_parallel = tvm.build(s, [A, B, C], tgt, target_host=tgt_host, name="myadd_parallel")
+fadd_parallel(a, b, c)
+
+tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
+
+################################################################################
+# Updating the Schedule to Use Vectorization
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# Modern CPUs can also perform SIMD operations on floating-point values, and
+# we can apply further scheduling primitives to our computation to take
+# advantage of this. This takes multiple steps. First, we split the loop into
+# inner and outer loops using the split scheduling primitive. The inner loop
+# can then be mapped to SIMD instructions with the vectorize scheduling
+# primitive, and the outer loop can be parallelized with the parallel
+# scheduling primitive. Because the inner loop is the one that is vectorized,
+# choose the split factor to match your CPU's SIMD width (for example, 8
+# float32 lanes with AVX2).
+
+# Recreate the schedule, since we modified it with the parallel operation in the previous example
+n = te.var("n")
+A = te.placeholder((n,), name="A")
+B = te.placeholder((n,), name="B")
+C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
+
+s = te.create_schedule(C.op)
+
+factor = 4
+
+outer, inner = s[C].split(C.op.axis[0], factor=factor)
+s[C].parallel(outer)
+s[C].vectorize(inner)
+
+print(tvm.lower(s, [A, B, C], simple_mode=True))
+
+################################################################################
+# We've defined, scheduled, and compiled a vector addition operator, which we
+# were then able to execute on the TVM runtime. We can save the operator as a
+# library, which we can then load later using the TVM runtime.
+
+################################################################################
+# Targeting Vector Addition for GPUs (Optional)
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# TVM is capable of targeting multiple architectures. In the next example, we
+# will target compilation of the vector addition to GPUs.
+
+# Change this target to match your GPU backend, for example cuda, opencl, or rocm
+tgt_gpu = "cuda"
+
+# Recreate the schedule
+n = te.var("n")
+A = te.placeholder((n,), name="A")
+B = te.placeholder((n,), name="B")
+C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
+print(type(C))
+
+s = te.create_schedule(C.op)
+
 bx, tx = s[C].split(C.op.axis[0], factor=64)
 
-######################################################################
-# Finally we bind the iteration axis bx and tx to threads in the GPU
-# compute grid. These are GPU specific constructs that allow us
-# to generate code that runs on GPU.
-#
-if tgt == "cuda" or tgt == "rocm" or tgt.startswith("opencl"):
+################################################################################
+# Finally, we bind the iteration axes bx and tx to threads in the GPU compute
+# grid. These are GPU-specific constructs that allow us to generate code that
+# runs on the GPU.
+
+if tgt_gpu == "cuda" or tgt_gpu == "rocm" or tgt_gpu.startswith("opencl"):
     s[C].bind(bx, te.thread_axis("blockIdx.x"))
     s[C].bind(tx, te.thread_axis("threadIdx.x"))
 
-######################################################################
-# Compilation
-# -----------
-# After we have finished specifying the schedule, we can compile it
-# into a TVM function. By default TVM compiles into a type-erased
-# function that can be directly called from the python side.
-#
-# In the following line, we use tvm.build to create a function.
-# The build function takes the schedule, the desired signature of the
-# function (including the inputs and outputs) as well as target language
-# we want to compile to.
-#
-# The result of compilation fadd is a GPU device function (if GPU is
-# involved) as well as a host wrapper that calls into the GPU
-# function.  fadd is the generated host wrapper function, it contains
-# a reference to the generated device function internally.
-#
-fadd = tvm.build(s, [A, B, C], tgt, target_host=tgt_host, name="myadd")
+fadd = tvm.build(s, [A, B, C], tgt_gpu, target_host=tgt_host, name="myadd")
 
-######################################################################
-# Run the Function
-# ----------------
-# The compiled TVM function is exposes a concise C API
-# that can be invoked from any language.
-#
-# We provide a minimal array API in python to aid quick testing and prototyping.
-# The array API is based on the `DLPack <https://github.com/dmlc/dlpack>`_ standard.
+################################################################################
+# The compiled TVM function exposes a concise C API that can be invoked from
+# any language.
 #
-# - We first create a GPU context.
-# - Then tvm.nd.array copies the data to the GPU.
-# - fadd runs the actual computation.
-# - asnumpy() copies the GPU array back to the CPU and we can use this to verify correctness
+# We provide a minimal array API in Python to aid quick testing and
+# prototyping. The array API is based on the DLPack standard.
 #
-ctx = tvm.context(tgt, 0)
+# We first create a GPU context. Then tvm.nd.array copies the data to the GPU.
+# fadd runs the actual computation. asnumpy() copies the GPU array back to the
+# CPU and we can use this to verify correctness

Review comment:
       ```suggestion
   # We first create a GPU context. Then tvm.nd.array copies the data to the GPU,
   # fadd runs the actual computation, and asnumpy() copies the GPU array back to the
   # CPU (so we can verify correctness).
   ```
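For readers following the bind step in this hunk: a minimal pure-Python sketch (not TVM code) of the loop nest that `s[C].split(C.op.axis[0], factor=64)` describes, i.e. what the old tutorial's C snippet showed. The function name `split_vector_add` is hypothetical; on a GPU, `bx` and `tx` become `blockIdx.x` and `threadIdx.x` after the `bind` calls instead of Python loops.

```python
# Pure-Python sketch (not TVM code) of the loop nest produced by
# s[C].split(C.op.axis[0], factor=64). After bind(), bx maps to blockIdx.x
# and tx to threadIdx.x, so each (bx, tx) pair is one GPU thread.
import math


def split_vector_add(a, b, factor=64):
    n = len(a)
    c = [0.0] * n
    for bx in range(math.ceil(n / factor)):   # outer axis -> blockIdx.x
        for tx in range(factor):               # inner axis -> threadIdx.x
            i = bx * factor + tx
            if i < n:                          # guard for the partial last block
                c[i] = a[i] + b[i]
    return c


if __name__ == "__main__":
    a = [float(i) for i in range(100)]
    b = [2.0 * i for i in range(100)]
    print(split_vector_add(a, b)[:4])  # [0.0, 3.0, 6.0, 9.0]
```

The bounds check is what lets `n` stay symbolic: the last block may be only partly filled.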

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -37,124 +61,190 @@
 # Global declarations of environment.
 
 tgt_host = "llvm"
-# Change it to respective GPU if gpu is enabled Ex: cuda, opencl, rocm
-tgt = "cuda"
 
-######################################################################
-# Vector Add Example
-# ------------------
-# In this tutorial, we will use a vector addition example to demonstrate
-# the workflow.
-#
+# You will get better performance if you can identify the CPU you are targeting and specify it.
+# For example, ``tgt = "llvm -mcpu=broadwell"``
+tgt = "llvm"
 
 ######################################################################
-# Describe the Computation
-# ------------------------
-# As a first step, we need to describe our computation.
-# TVM adopts tensor semantics, with each intermediate result
-# represented as a multi-dimensional array. The user needs to describe
-# the computation rule that generates the tensors.
-#
-# We first define a symbolic variable n to represent the shape.
-# We then define two placeholder Tensors, A and B, with given shape (n,)
-#
-# We then describe the result tensor C, with a compute operation.  The
-# compute function takes the shape of the tensor, as well as a lambda
-# function that describes the computation rule for each position of
-# the tensor.
-#
-# No computation happens during this phase, as we are only declaring how
-# the computation should be done.
-#
+# Describing the Vector Computation
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# We describe a vector addition computation. TVM adopts tensor semantics, with
+# each intermediate result represented as a multi-dimensional array. The user
+# needs to describe the computation rule that generates the tensors. We first
+# define a symbolic variable n to represent the shape. We then define two
+# placeholder Tensors, A and B, with given shape (n,). We then describe the
+# result tensor C, with a compute operation. The compute function takes the
+# shape of the tensor, as well as a lambda function that describes the
+# computation rule for each position of the tensor. Note that while n is a
+# variable, it defines a consistent shape between the A, B and C tensors.
+# Remember, no actual computation happens during this phase, as we are only
+# declaring how the computation should be done.
+
 n = te.var("n")
 A = te.placeholder((n,), name="A")
 B = te.placeholder((n,), name="B")
 C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
 print(type(C))
 
 ######################################################################
-# Schedule the Computation
-# ------------------------
-# While the above lines describe the computation rule, we can compute
-# C in many ways since the axis of C can be computed in a data
-# parallel manner.  TVM asks the user to provide a description of the
-# computation called a schedule.
+# Create a Default Schedule for the Computation
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# While the above lines describe the computation rule, we can compute C in many
+# ways since the axis of C can be computed in a data parallel manner. TVM asks
+# the user to provide a description of the computation called a schedule.
 #
-# A schedule is a set of transformation of computation that transforms
-# the loop of computations in the program.
+# A schedule is a set of transformations of a computation that transform the
+# loops of the program.
 #
-# After we construct the schedule, by default the schedule computes
-# C in a serial manner in a row-major order.
+# By default, the schedule computes C serially, in row-major order, equivalent
+# to the following C code:
 #
 # .. code-block:: c
 #
 #   for (int i = 0; i < n; ++i) {
 #     C[i] = A[i] + B[i];
 #   }
-#
+
 s = te.create_schedule(C.op)
 
 ######################################################################
-# We used the split construct to split the first axis of C,
-# this will split the original iteration axis into product of
-# two iterations. This is equivalent to the following code.
-#
-# .. code-block:: c
-#
-#   for (int bx = 0; bx < ceil(n / 64); ++bx) {
-#     for (int tx = 0; tx < 64; ++tx) {
-#       int i = bx * 64 + tx;
-#       if (i < n) {
-#         C[i] = A[i] + B[i];
-#       }
-#     }
-#   }
-#
+# Compile and Evaluate the Default Schedule
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# With the schedule created, we can now compile it down to our target language
+# and architecture, in this case LLVM to a CPU. We provide TVM with the
+# schedule, a list of the TE expressions that are in the schedule, the target
+# and host, and the name of the function we are producing. The result is a
+# type-erased function that can be called directly from Python.
+#
+# In the following line, we use tvm.build to create a function. The build
+# function takes the schedule, the desired signature of the function (including
+# the inputs and outputs) as well as the target language we want to compile to.
+
+fadd = tvm.build(s, [A, B, C], tgt, target_host=tgt_host, name="myadd")
+
+################################################################################
+# Let's run the function and compare the output to the same computation in
+# numpy. We begin by creating a context, which represents the device that the
+# compiled function will run on; in this case it is the CPU targeted by LLVM.
+# We can then initialize the tensors in our context and perform the custom
+# addition operation. To verify that the computation is correct, we compare
+# the contents of the output tensor c to the same computation performed by
+# numpy.
+
+ctx = tvm.context(tgt, 0)
+
+n = 1024
+a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
+b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
+c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
+fadd(a, b, c)
+tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
+
+################################################################################
+# Updating the Schedule to Use Paralleism
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# Now that we've illustrated the fundamentals of TE, let's go deeper into what
+# schedules do, and how they can be used to optimize tensor expressions for
+# different architectures. A schedule is a series of steps that are applied to
+# a TE to transform it in a number of different ways. When a schedule is
+# applied to a TE, the inputs and outputs remain the same, but when compiled
+# the implementation of the expression can change. This tensor addition, in the
+# default schedule, is run serially but is easy to parallelize across all of
+# the processor threads. We can apply the parallel schedule operation to our
+# computation.
+
+s[C].parallel(C.op.axis[0])
+
+print(tvm.lower(s, [A, B, C], simple_mode=True))

Review comment:
       I would explain what tvm.lower does.
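As a starting point: `tvm.lower` turns the schedule into TVM's lowered IR without generating machine code, so printing it is a cheap way to see how a schedule primitive changed the loop nest; after `s[C].parallel(...)` the loop over C's axis should appear annotated as parallel. The sketch below is only a pure-Python analogy for what that annotation means (same inputs, outputs and per-element rule, with iterations split across workers). `parallel_vector_add` and the use of `ThreadPoolExecutor` are illustrative assumptions, not how the TVM runtime schedules threads, and CPython's GIL means this version will not actually run faster.

```python
# Pure-Python analogy (not TVM's runtime) for s[C].parallel(C.op.axis[0]):
# the iterations of the loop are spread across worker threads, while the
# inputs, outputs and per-element rule stay the same.
from concurrent.futures import ThreadPoolExecutor


def parallel_vector_add(a, b, num_workers=4):
    n = len(a)
    c = [0.0] * n
    chunk = (n + num_workers - 1) // num_workers  # ceil(n / num_workers)

    def work(start):
        # Each worker handles one contiguous slice of the iteration space.
        for i in range(start, min(start + chunk, n)):
            c[i] = a[i] + b[i]

    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        # list() forces completion and re-raises any worker exception.
        list(pool.map(work, range(0, n, max(chunk, 1))))
    return c
```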

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -255,41 +326,39 @@
 fadd1(a, b, c)
 tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 
-######################################################################
+################################################################################
 # Pack Everything into One Library
-# --------------------------------
-# In the above example, we store the device and host code separately.
-# TVM also supports export everything as one shared library.
-# Under the hood, we pack the device modules into binary blobs and link
-# them together with the host code.
-# Currently we support packing of Metal, OpenCL and CUDA modules.
-#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# In the above example, we store the device and host code separately. TVM also
+# supports exporting everything as one shared library. Under the hood, we pack
+# the device modules into binary blobs and link them together with the host
+# code. Currently we support packing of Metal, OpenCL and CUDA modules.
+
 fadd.export_library(temp.relpath("myadd_pack.so"))
 fadd2 = tvm.runtime.load_module(temp.relpath("myadd_pack.so"))
 fadd2(a, b, c)
 tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 
-######################################################################
+################################################################################
 # .. note:: Runtime API and Thread-Safety
 #
-#   The compiled modules of TVM do not depend on the TVM compiler.
-#   Instead, they only depend on a minimum runtime library.
-#   The TVM runtime library wraps the device drivers and provides
-#   thread-safe and device agnostic calls into the compiled functions.
-#
-#   This means that you can call the compiled TVM functions from any thread,
-#   on any GPUs.
+#   The compiled modules of TVM do not depend on the TVM compiler. Instead,
+#   they only depend on a minimal runtime library. The TVM runtime library
+#   wraps the device drivers and provides thread-safe and device-agnostic calls
+#   into the compiled functions.
 #
+#   This means that you can call the compiled TVM functions from any thread, on
+#   any GPUs.

Review comment:
       Specify that the code has to be compiled for a specific gpu though.

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -255,41 +326,39 @@
 fadd1(a, b, c)
 tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 
-######################################################################
+################################################################################
 # Pack Everything into One Library
-# --------------------------------
-# In the above example, we store the device and host code separately.
-# TVM also supports export everything as one shared library.
-# Under the hood, we pack the device modules into binary blobs and link
-# them together with the host code.
-# Currently we support packing of Metal, OpenCL and CUDA modules.
-#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# In the above example, we store the device and host code separately. TVM also
+# supports exporting everything as one shared library. Under the hood, we pack
+# the device modules into binary blobs and link them together with the host
+# code. Currently we support packing of Metal, OpenCL and CUDA modules.
+
 fadd.export_library(temp.relpath("myadd_pack.so"))
 fadd2 = tvm.runtime.load_module(temp.relpath("myadd_pack.so"))
 fadd2(a, b, c)
 tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 
-######################################################################
+################################################################################
 # .. note:: Runtime API and Thread-Safety
 #
-#   The compiled modules of TVM do not depend on the TVM compiler.
-#   Instead, they only depend on a minimum runtime library.
-#   The TVM runtime library wraps the device drivers and provides
-#   thread-safe and device agnostic calls into the compiled functions.
-#
-#   This means that you can call the compiled TVM functions from any thread,
-#   on any GPUs.
+#   The compiled modules of TVM do not depend on the TVM compiler. Instead,
+#   they only depend on a minimal runtime library. The TVM runtime library
+#   wraps the device drivers and provides thread-safe and device-agnostic calls
+#   into the compiled functions.
 #
+#   This means that you can call the compiled TVM functions from any thread, on
+#   any GPUs.
 
-######################################################################
+################################################################################
 # Generate OpenCL Code
 # --------------------
-# TVM provides code generation features into multiple backends,
-# we can also generate OpenCL code or LLVM code that runs on CPU backends.
+# TVM provides code generation features into multiple backends, we can also

Review comment:
       ```suggestion
   # TVM provides code generation features into multiple backends. We can also
   ```

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -302,18 +371,437 @@
     fadd_cl(a, b, c)
     tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 
-######################################################################
-# Summary
-# -------
-# This tutorial provides a walk through of TVM workflow using
-# a vector add example. The general workflow is
+################################################################################
+# .. note:: Code Specialization
+#
+#   As you may have noticed, the declarations of A, B and C all take the same
+#   shape argument, n. TVM will take advantage of this to pass only a single
+#   shape argument to the kernel, as you will find in the printed device code.
+#   This is one form of specialization.
+#
+#   On the host side, TVM will automatically generate check code that checks
+#   the constraints in the parameters. So if you pass arrays with different
+#   shapes into fadd, an error will be raised.
+#
+#   We can do more specializations. For example, we can write :code:`n =
+#   tvm.runtime.convert(1024)` instead of :code:`n = te.var("n")`, in the
+#   computation declaration. The generated function will only take vectors with
+#   length 1024.
+
+################################################################################
+# .. note:: TE Scheduling Primitives
+#
+#   TVM includes a number of different scheduling primitives:
+#
+#   - split: splits a specified axis into two axes by the defined factor.
+#   - tile: splits two axes of a computation by the defined factors, producing a tiled loop nest.
+#   - fuse: fuses two consecutive axes of one computation.
+#   - reorder: reorders the axes of a computation into a defined order.
+#   - bind: binds a computation to a specific thread, useful in GPU programming.
+#   - compute_at: by default, TVM computes tensors at the root. compute_at specifies
+#     that one tensor should be computed at a given axis of another operator's computation.
+#   - compute_inline: when marked inline, a computation will be expanded then inserted into the
+#     address where the tensor is required.
+#   - compute_root: moves a computation to the root stage.
+#
+#   A complete description of these primitives can be found in the
+#   `Schedule Primitives <https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html>`_ docs page.
+
+################################################################################
+# Example 2: Manually Optimizing Matrix Multiplication with TE
+# ------------------------------------------------------------
+#
+# Now we will consider a second, more advanced example, demonstrating how with
+# just 18 lines of Python code TVM can achieve up to an 18x speedup on a
+# common matrix multiplication operation.
+#
+# **There are two important optimizations for compute-intensive applications
+# executed on CPU:**
+#
+# 1. Increase the cache hit rate of memory access. Both complex numerical
+#    computation and hot-spot memory access can be accelerated from high cache hit

Review comment:
       ```suggestion
   #    computation and hot-spot memory access can be accelerated by a high cache hit
   ```
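On the "Code Specialization" note in this hunk: a hypothetical pure-Python sketch of the kind of host-side check it describes. Because A, B and C share one symbolic length `n`, the wrapper needs only one shape value and can reject mismatched arrays; the optional `fixed_n` argument mimics specializing with `tvm.runtime.convert(1024)`. The names `checked_add` and `fixed_n` are illustrative, not TVM APIs, and the real check is generated code rather than Python.

```python
# Hypothetical sketch of the host-side check described in the
# "Code Specialization" note: A, B and C share one symbolic length n,
# so the wrapper reads n once and rejects mismatched shapes.
def checked_add(a, b, c, fixed_n=None):
    n = len(a)  # single shape value shared by A, B and C
    if len(b) != n or len(c) != n:
        raise ValueError("A, B and C must have the same length n")
    if fixed_n is not None and n != fixed_n:
        # Mimics specializing n = tvm.runtime.convert(1024): only that length is accepted.
        raise ValueError("this specialization only accepts length %d" % fixed_n)
    for i in range(n):
        c[i] = a[i] + b[i]
    return c
```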

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -302,18 +371,437 @@
     fadd_cl(a, b, c)
     tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 
-######################################################################
-# Summary
-# -------
-# This tutorial provides a walk through of TVM workflow using
-# a vector add example. The general workflow is
+################################################################################
+# .. note:: Code Specialization
+#
+#   As you may have noticed, the declarations of A, B and C all take the same
+#   shape argument, n. TVM will take advantage of this to pass only a single
+#   shape argument to the kernel, as you will find in the printed device code.
+#   This is one form of specialization.
+#
+#   On the host side, TVM will automatically generate check code that checks
+#   the constraints in the parameters. So if you pass arrays with different
+#   shapes into fadd, an error will be raised.
+#
+#   We can do more specializations. For example, we can write :code:`n =
+#   tvm.runtime.convert(1024)` instead of :code:`n = te.var("n")`, in the
+#   computation declaration. The generated function will only take vectors with
+#   length 1024.
+
+################################################################################
+# .. note:: TE Scheduling Primitives
+#
+#   TVM includes a number of different scheduling primitives:
+#
+#   - split: splits a specified axis into two axes by the defined factor.
+#   - tile: splits two axes of a computation by the defined factors, producing a tiled loop nest.
+#   - fuse: fuses two consecutive axes of one computation.
+#   - reorder: reorders the axes of a computation into a defined order.
+#   - bind: binds a computation to a specific thread, useful in GPU programming.
+#   - compute_at: by default, TVM computes tensors at the root. compute_at specifies
+#     that one tensor should be computed at a given axis of another operator's computation.
+#   - compute_inline: when marked inline, a computation will be expanded then inserted into the
+#     address where the tensor is required.
+#   - compute_root: moves a computation to the root stage.
+#
+#   A complete description of these primitives can be found in the
+#   `Schedule Primitives <https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html>`_ docs page.
+
+################################################################################
+# Example 2: Manually Optimizing Matrix Multiplication with TE
+# ------------------------------------------------------------
+#
+# Now we will consider a second, more advanced example, demonstrating how with
+# just 18 lines of Python code TVM can achieve up to an 18x speedup on a
+# common matrix multiplication operation.
+#
+# **There are two important optimizations for compute-intensive applications
+# executed on CPU:**
+#
+# 1. Increase the cache hit rate of memory access. Both complex numerical
+#    computation and hot-spot memory access can be accelerated from high cache hit
+#    rate. This requires us to transform the origin memory access pattern to the
+#    pattern fits the cache policy.

Review comment:
       ```suggestion
   #    rate. This requires us to transform the origin memory access pattern to a pattern that fits the cache policy.
   ```
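To make "a pattern that fits the cache policy" concrete, a minimal pure-Python sketch of loop blocking (tiling) for `C = A x B`: the output is visited in `bn x bn` tiles so that data loaded for one tile is reused while it is still likely to be cached. The block size `bn` and the function names are assumptions for illustration; pure Python will not show the cache effect, it only shows that the reordered loops compute the same result.

```python
# Pure-Python sketch of loop blocking (tiling) for C = A x B.
# Lists of lists stand in for tensors; bn is an illustrative block size.
def matmul_naive(A, B):
    M, K, N = len(A), len(B), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for i in range(M):
        for j in range(N):
            for k in range(K):
                C[i][j] += A[i][k] * B[k][j]
    return C


def matmul_blocked(A, B, bn=4):
    M, K, N = len(A), len(B), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    # Walk the output in bn x bn tiles; within a tile, the same rows of A
    # and the same columns of B are reused for every k.
    for i0 in range(0, M, bn):
        for j0 in range(0, N, bn):
            for k in range(K):
                for i in range(i0, min(i0 + bn, M)):
                    a_ik = A[i][k]
                    for j in range(j0, min(j0 + bn, N)):
                        C[i][j] += a_ik * B[k][j]
    return C
```

Only the loop order changes; each `C[i][j]` still accumulates over `k` in the same order, which is the property a schedule transformation must preserve.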

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -302,18 +371,437 @@
     fadd_cl(a, b, c)
     tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 
-######################################################################
-# Summary
-# -------
-# This tutorial provides a walk through of TVM workflow using
-# a vector add example. The general workflow is
+################################################################################
+# .. note:: Code Specialization
+#
+#   As you may have noticed, the declarations of A, B and C all take the same
+#   shape argument, n. TVM will take advantage of this to pass only a single
+#   shape argument to the kernel, as you will find in the printed device code.
+#   This is one form of specialization.
+#
+#   On the host side, TVM will automatically generate check code that checks
+#   the constraints in the parameters. So if you pass arrays with different
+#   shapes into fadd, an error will be raised.
+#
+#   We can do more specializations. For example, we can write :code:`n =
+#   tvm.runtime.convert(1024)` instead of :code:`n = te.var("n")`, in the
+#   computation declaration. The generated function will only take vectors with
+#   length 1024.
+
+################################################################################
+# .. note:: TE Scheduling Primitives
+#
+#   TVM includes a number of different scheduling primitives:
+#
+#   - split: splits a specified axis into two axes by the defined factor.
+#   - tile: splits two axes of a computation by the defined factors, producing a tiled loop nest.
+#   - fuse: fuses two consecutive axes of one computation.
+#   - reorder: reorders the axes of a computation into a defined order.
+#   - bind: binds a computation to a specific thread, useful in GPU programming.
+#   - compute_at: by default, TVM computes tensors at the root. compute_at specifies
+#     that one tensor should be computed at a given axis of another operator's computation.
+#   - compute_inline: when marked inline, a computation will be expanded then inserted into the
+#     address where the tensor is required.
+#   - compute_root: moves a computation to the root stage.
+#
+#   A complete description of these primitives can be found in the
+#   `Schedule Primitives <https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html>`_ docs page.
+
+################################################################################
+# Example 2: Manually Optimizing Matrix Multiplication with TE
+# ------------------------------------------------------------
+#
+# Now we will consider a second, more advanced example, demonstrating how with
+# just 18 lines of Python code TVM can achieve up to an 18x speedup on a
+# common matrix multiplication operation.
+#
+# **There are two important optimizations for compute-intensive applications
+# executed on CPU:**
+#
+# 1. Increase the cache hit rate of memory access. Both complex numerical
+#    computation and hot-spot memory access can be accelerated from high cache hit
+#    rate. This requires us to transform the origin memory access pattern to the
+#    pattern fits the cache policy.
+# 2. SIMD (Single Instruction, Multiple Data), also known as the vector processing
+#    unit. Every time, a small batch of data, rather than a single grid, will be
+#    processed. This requires us to transform the data access pattern in the loop

Review comment:
       ```suggestion
   #    unit. On each cycle instead of processing a single value, SIMD can process a small batch of data.
   #    This requires us to transform the data access pattern in the loop
   ```
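A pure-Python sketch of the loop shape point 2 is asking for: split the loop so the inner loop covers a fixed number of contiguous elements with an identical body, plus a scalar tail for the leftovers. `LANES = 8` assumes 8 float32 values per 256-bit AVX2 register; the function name is hypothetical, and Python itself emits no SIMD instructions, so this only illustrates the transformation.

```python
# Pure-Python sketch of the loop shape that SIMD lowering needs: the inner
# loop runs over a fixed number of contiguous lanes with an identical body.
# LANES = 8 assumes 8 float32 values per 256-bit AVX2 register.
LANES = 8


def scaled_add_lanes(x, y, alpha):
    n = len(x)
    out = [0.0] * n
    main = n - n % LANES  # elements that fill whole vectors
    for jo in range(0, main, LANES):
        # Uniform body over LANES contiguous elements: a vectorizer can map
        # this inner loop to vector multiply and add instructions.
        for ji in range(LANES):
            j = jo + ji
            out[j] = alpha * x[j] + y[j]
    for j in range(main, n):
        # Scalar tail for the elements that do not fill a full vector.
        out[j] = alpha * x[j] + y[j]
    return out
```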

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -302,18 +371,437 @@
     fadd_cl(a, b, c)
     tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 
-######################################################################
-# Summary
-# -------
-# This tutorial provides a walk through of TVM workflow using
-# a vector add example. The general workflow is
+################################################################################
+# .. note:: Code Specialization
+#
+#   As you may have noticed, the declarations of A, B and C all take the same
+#   shape argument, n. TVM will take advantage of this to pass only a single
+#   shape argument to the kernel, as you will find in the printed device code.
+#   This is one form of specialization.
+#
+#   On the host side, TVM will automatically generate check code that checks
+#   the constraints in the parameters. So if you pass arrays with different
+#   shapes into fadd, an error will be raised.
+#
+#   We can do more specializations. For example, we can write :code:`n =
+#   tvm.runtime.convert(1024)` instead of :code:`n = te.var("n")`, in the
+#   computation declaration. The generated function will only take vectors with
+#   length 1024.
+
+################################################################################
+# .. note:: TE Scheduling Primitives
+#
+#   TVM includes a number of different scheduling primitives:
+#
+#   - split: splits a specified axis into two axes by the defined factor.
+#   - tile: splits two axes of a computation by the defined factors, producing a tiled loop nest.
+#   - fuse: fuses two consecutive axes of one computation.
+#   - reorder: reorders the axes of a computation into a defined order.
+#   - bind: binds a computation to a specific thread, useful in GPU programming.
+#   - compute_at: by default, TVM computes tensors at the root. compute_at specifies
+#     that one tensor should be computed at a given axis of another operator's computation.
+#   - compute_inline: when marked inline, a computation will be expanded then inserted into the
+#     address where the tensor is required.
+#   - compute_root: moves a computation to the root stage.
+#
+#   A complete description of these primitives can be found in the
+#   `Schedule Primitives <https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html>`_ docs page.
+
+################################################################################
+# Example 2: Manually Optimizing Matrix Multiplication with TE
+# ------------------------------------------------------------
+#
+# Now we will consider a second, more advanced example, demonstrating how with
+# just 18 lines of Python code TVM can achieve up to an 18x speedup on a
+# common matrix multiplication operation.
+#
+# **There are two important optimizations for compute-intensive applications
+# executed on CPU:**
+#
+# 1. Increase the cache hit rate of memory access. Both complex numerical
+#    computation and hot-spot memory access can be accelerated from high cache hit
+#    rate. This requires us to transform the origin memory access pattern to the
+#    pattern fits the cache policy.
+# 2. SIMD (Single Instruction, Multiple Data), also known as the vector processing
+#    unit. Every time, a small batch of data, rather than a single grid, will be
+#    processed. This requires us to transform the data access pattern in the loop
+#    body into a uniform pattern so that the LLVM backend can lower it to SIMD.
+#
+# The techniques used in this tutorial are a subset of tricks mentioned in this
+# `repository <https://github.com/flame/how-to-optimize-gemm>`_. Some of them
+# are applied automatically by the TVM abstraction, but others cannot be
+# applied directly due to TVM constraints.
+#
+# All the experiment results mentioned below, are executed on 2015 15" MacBook
+# equipped with Intel i7-4770HQ CPU. The cache line size should  be 64 bytes for

Review comment:
       ```suggestion
   # All the experiment results mentioned below are executed on 2015 15" MacBook
   # equipped with Intel i7-4770HQ CPU. The cache line size should be 64 bytes for
   ```
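The 64-byte figure translates into tile sizes as follows, assuming float32 elements and an illustrative 32 x 32 block (the block size is an assumption here, not taken from this hunk):

```python
# Back-of-the-envelope arithmetic for the cache-line remark, assuming
# float32 data (4 bytes), 64-byte cache lines, and an illustrative
# 32 x 32 block (bn = 32 is an assumption).
CACHE_LINE_BYTES = 64
FLOAT32_BYTES = 4
floats_per_line = CACHE_LINE_BYTES // FLOAT32_BYTES             # 16 values per line

bn = 32
block_bytes = bn * bn * FLOAT32_BYTES                           # 4096 bytes per tile
lines_per_block_row = bn * FLOAT32_BYTES // CACHE_LINE_BYTES    # 2 lines per tile row
```

So one row of such a tile spans two cache lines when it starts on a line boundary, and the whole tile occupies 4 KB.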

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -302,18 +371,437 @@
     fadd_cl(a, b, c)
     tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 
-######################################################################
-# Summary
-# -------
-# This tutorial provides a walk through of TVM workflow using
-# a vector add example. The general workflow is
+################################################################################
+# .. note:: Code Specialization
+#
+#   As you may have noticed, the declarations of A, B and C all take the same
+#   shape argument, n. TVM will take advantage of this to pass only a single
+#   shape argument to the kernel, as you will find in the printed device code.
+#   This is one form of specialization.
+#
+#   On the host side, TVM will automatically generate check code that checks
+#   the constraints in the parameters. So if you pass arrays with different
+#   shapes into fadd, an error will be raised.
+#
+#   We can do more specializations. For example, we can write :code:`n =
+#   tvm.runtime.convert(1024)` instead of :code:`n = te.var("n")`, in the
+#   computation declaration. The generated function will only take vectors with
+#   length 1024.
+
+################################################################################
+# .. note:: TE Scheduling Primitives
+#
+#   TVM includes a number of different scheduling primitives:
+#
+#   - split: splits a specified axis into two axes by the defined factor.
+#   - tile: splits two axes of a computation by the defined factors, producing a tiled loop nest.
+#   - fuse: fuses two consecutive axes of one computation.
+#   - reorder: reorders the axes of a computation into a defined order.
+#   - bind: binds a computation to a specific thread, useful in GPU programming.
+#   - compute_at: by default, TVM computes tensors at the root. compute_at specifies
+#     that one tensor should be computed at a given axis of another operator's computation.
+#   - compute_inline: when marked inline, a computation will be expanded then inserted into the
+#     address where the tensor is required.
+#   - compute_root: moves a computation to the root stage.
+#
+#   A complete description of these primitives can be found in the
+#   `Schedule Primitives <https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html>`_ docs page.
+
+################################################################################
+# Example 2: Manually Optimizing Matrix Multiplication with TE
+# ------------------------------------------------------------
+#
+# Now we will consider a second, more advanced example, demonstrating how with
+# just 18 lines of Python code TVM can achieve up to an 18x speedup on a
+# common matrix multiplication operation.
+#
+# **There are two important optimizations for compute-intensive applications
+# executed on CPU:**
+#
+# 1. Increase the cache hit rate of memory access. Both complex numerical
+#    computation and hot-spot memory access can be accelerated from high cache hit
+#    rate. This requires us to transform the origin memory access pattern to the
+#    pattern fits the cache policy.
+# 2. SIMD (Single Instruction, Multiple Data), also known as the vector processing
+#    unit. Every time, a small batch of data, rather than a single grid, will be
+#    processed. This requires us to transform the data access pattern in the loop
+#    body into a uniform pattern so that the LLVM backend can lower it to SIMD.
+#
+# The techniques used in this tutorial are a subset of tricks mentioned in this
+# `repository <https://github.com/flame/how-to-optimize-gemm>`_. Some of them
+# are applied automatically by the TVM abstraction, but others cannot be
+# applied directly due to TVM constraints.
+#
+# All the experiment results mentioned below, are executed on 2015 15" MacBook
+# equipped with Intel i7-4770HQ CPU. The cache line size should  be 64 bytes for
+# all the x86 CPUs.
+
+################################################################################
+# Preparation and Performance Baseline
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# We begin by collecting performance data on the ``numpy`` implementation of
+# matrix multiplication.
+
+import tvm
+import tvm.testing
+from tvm import te
+import numpy
+import timeit
+
+# The size of the matrix
+# (M, K) x (K, N)
+# You are free to try out different shapes; sometimes TVM optimization outperforms numpy with MKL.
+M = 1024
+K = 1024
+N = 1024
+
+# The default tensor data type in tvm
+dtype = "float32"
+
+# Using Intel AVX2 (Advanced Vector Extensions) ISA for SIMD.
+# To get the best performance, please change the following line
+# to "llvm -mcpu=core-avx2", or to the specific type of CPU you use.
+target = "llvm"
+ctx = tvm.context(target, 0)
+
+# Randomly generated tensors for testing
+a = tvm.nd.array(numpy.random.rand(M, K).astype(dtype), ctx)
+b = tvm.nd.array(numpy.random.rand(K, N).astype(dtype), ctx)
+
+# Repeatedly perform a matrix multiplication to get a performance baseline
+# for the default numpy implementation
+np_repeat = 100
+np_running_time = timeit.timeit(
+    setup="import numpy\n"
+    "M = " + str(M) + "\n"
+    "K = " + str(K) + "\n"
+    "N = " + str(N) + "\n"
+    'dtype = "float32"\n'
+    "a = numpy.random.rand(M, K).astype(dtype)\n"
+    "b = numpy.random.rand(K, N).astype(dtype)\n",
+    stmt="answer = numpy.dot(a, b)",
+    number=np_repeat,
+)
+print("Numpy running time: %f" % (np_running_time / np_repeat))
+
+answer = numpy.dot(a.asnumpy(), b.asnumpy())
+
+################################################################################
+# Now, write a basic matrix multiplication using TVM TE and verify that it
+# produces the same results as the numpy implementation. We also write a
+# function that will help us measure the performance of the schedule
+# optimizations.
+
+# TVM Matrix Multiplication using TE
+k = te.reduce_axis((0, K), "k")
+A = te.placeholder((M, K), name="A")
+B = te.placeholder((K, N), name="B")
+C = te.compute((M, N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name="C")
+
+# Default schedule
+s = te.create_schedule(C.op)
+func = tvm.build(s, [A, B, C], target=target, name="mmult")
+assert func

Review comment:
       I'd remove the asserts.
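For the NumPy baseline above, a sketch that times `numpy.dot` by passing a callable to `timeit.timeit` instead of assembling a setup string, and checks the numbers with `numpy.testing.assert_allclose` rather than a bare `assert`. The helper name `time_numpy_matmul`, the 256 x 256 shape and the tolerance are illustrative assumptions, not part of the tutorial.

```python
import timeit

import numpy
import numpy.testing


def time_numpy_matmul(M, K, N, dtype="float32", repeat=10):
    """Return (mean seconds per numpy.dot call, result) for random (M, K) x (K, N) inputs."""
    a = numpy.random.rand(M, K).astype(dtype)
    b = numpy.random.rand(K, N).astype(dtype)
    # timeit accepts a callable directly, so no setup string is needed.
    total = timeit.timeit(lambda: numpy.dot(a, b), number=repeat)
    answer = numpy.dot(a, b)
    # Compare against an independent float64 computation with a tolerance
    # instead of asserting that an object is truthy.
    expected = a.astype("float64") @ b.astype("float64")
    numpy.testing.assert_allclose(answer, expected, rtol=1e-4)
    return total / repeat, answer


mean_time, answer = time_numpy_matmul(256, 256, 256)
print("Numpy running time: %f" % mean_time)
```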

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -302,18 +371,437 @@
     fadd_cl(a, b, c)
     tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 
-######################################################################
-# Summary
-# -------
-# This tutorial provides a walk through of TVM workflow using
-# a vector add example. The general workflow is
+################################################################################
+# .. note:: Code Specialization
+#
+#   As you may have noticed, the declarations of A, B and C all take the same
+#   shape argument, n. TVM will take advantage of this to pass only a single
+#   shape argument to the kernel, as you will find in the printed device code.
+#   This is one form of specialization.
+#
+#   On the host side, TVM will automatically generate code that checks the
+#   constraints on the parameters. So if you pass arrays with different
+#   shapes into fadd, an error will be raised.
+#
+#   We can do more specializations. For example, we can write :code:`n =
+#   tvm.runtime.convert(1024)` instead of :code:`n = te.var("n")`, in the
+#   computation declaration. The generated function will only take vectors with
+#   length 1024.
+
+################################################################################
+# .. note:: TE Scheduling Primitives
+#
+#   TVM includes a number of different scheduling primitives:
+#
+#   - split: splits a specified axis into two axes by the defined factor.
+#   - tile: splits a computation across two axes by the defined factors.
+#   - fuse: fuses two consecutive axes of one computation.
+#   - reorder: reorders the axes of a computation into a defined order.
+#   - bind: binds a computation to a specific thread, useful in GPU programming.
+#   - compute_at: by default, TVM computes tensors at the root. compute_at specifies
+#     that one tensor should be computed at a specified axis of another operator's computation.
+#   - compute_inline: when marked inline, a computation will be expanded and inserted at the
+#     point where the tensor is required.
+#   - compute_root: moves a computation to the root stage.
+#
+#   A complete description of these primitives can be found in the
+#   `Schedule Primitives <https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html>`_ docs page.
+
+################################################################################
+# Example 2: Manually Optimizing Matrix Multiplication with TE
+# ------------------------------------------------------------
+#
+# Now we will consider a second, more advanced example, demonstrating how,
+# with just 18 lines of Python code, TVM can achieve up to an 18x speedup on
+# a common matrix multiplication operation.
+#
+# **There are two important optimizations for compute-intensive applications
+# executed on a CPU:**
+#
+# 1. Increase the cache hit rate of memory access. Both complex numerical
+#    computation and hot-spot memory access can be accelerated by a high cache
+#    hit rate. This requires us to transform the original memory access pattern
+#    into one that fits the cache policy.
+# 2. SIMD (Single Instruction, Multiple Data), also known as vector processing.
+#    Each instruction processes a small batch of data rather than a single
+#    element. This requires us to transform the data access pattern in the loop
+#    body into a uniform pattern so that the LLVM backend can lower it to SIMD.
+#
+# The techniques used in this tutorial are a subset of tricks mentioned in this
+# `repository <https://github.com/flame/how-to-optimize-gemm>`_. Some of them
+# are applied automatically by TVM's abstractions, but some of them cannot
+# be simply applied due to TVM constraints.

Review comment:
       ```suggestion
   # be automatically applied due to TVM constraints.
   ```
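The `split`, `tile` and `reorder` entries in the primitive list above can be illustrated in plain Python. Splitting an axis of extent `n` by `factor` yields an outer and an inner loop whose indices recombine as `outer * factor + inner`, with a bounds check for the tail when `n` is not a multiple of `factor`; tiling splits two axes and reorders the loops to `(xo, yo, xi, yi)`, so the same points are visited block by block. This is a sketch of the loop structure only; the helper names are hypothetical and no TVM code is involved.

```python
def split_indices(n, factor):
    """Visit 0..n-1 through an outer/inner loop pair, like split(axis, factor)."""
    visited = []
    for outer in range((n + factor - 1) // factor):  # ceil(n / factor) outer iterations
        for inner in range(factor):
            i = outer * factor + inner
            if i < n:  # tail guard when n is not a multiple of factor
                visited.append(i)
    return visited


def tiled_then_reordered(m, n, factor):
    """Split x and y by factor and nest the loops as (xo, yo, xi, yi), like tile()."""
    order = []
    for xo in range((m + factor - 1) // factor):
        for yo in range((n + factor - 1) // factor):
            for xi in range(factor):
                for yi in range(factor):
                    x, y = xo * factor + xi, yo * factor + yi
                    if x < m and y < n:
                        order.append((x, y))
    return order


# Splitting does not change which indices are visited, only how the loops are nested.
assert split_indices(10, 4) == list(range(10))
# Tiling visits every (x, y) exactly once, but one block at a time.
tiles = tiled_then_reordered(4, 4, 2)
assert sorted(tiles) == [(x, y) for x in range(4) for y in range(4)]
print(tiles[:4])  # the first 2x2 tile: (0, 0), (0, 1), (1, 0), (1, 1)
```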

##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -302,18 +371,437 @@
     fadd_cl(a, b, c)
     tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 
-######################################################################
-# Summary
-# -------
-# This tutorial provides a walk through of TVM workflow using
-# a vector add example. The general workflow is
+################################################################################
+# .. note:: Code Specialization
+#
+#   As you may have noticed, the declarations of A, B and C all take the same
+#   shape argument, n. TVM will take advantage of this to pass only a single
+#   shape argument to the kernel, as you will find in the printed device code.
+#   This is one form of specialization.
+#
+#   On the host side, TVM will automatically generate code that checks the
+#   constraints on the parameters. So if you pass arrays with different
+#   shapes into fadd, an error will be raised.
+#
+#   We can do more specializations. For example, we can write :code:`n =
+#   tvm.runtime.convert(1024)` instead of :code:`n = te.var("n")`, in the
+#   computation declaration. The generated function will only take vectors with
+#   length 1024.
+
+################################################################################
+# .. note:: TE Scheduling Primitives
+#
+#   TVM includes a number of different scheduling primitives:
+#
+#   - split: splits a specified axis into two axes by the defined factor.
+#   - tile: splits a computation across two axes by the defined factors.
+#   - fuse: fuses two consecutive axes of one computation.
+#   - reorder: reorders the axes of a computation into a defined order.
+#   - bind: binds a computation to a specific thread, useful in GPU programming.
+#   - compute_at: by default, TVM computes tensors at the root. compute_at specifies
+#     that one tensor should be computed at a specified axis of another operator's computation.
+#   - compute_inline: when marked inline, a computation will be expanded and inserted at the
+#     point where the tensor is required.
+#   - compute_root: moves a computation to the root stage.
+#
+#   A complete description of these primitives can be found in the
+#   `Schedule Primitives <https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html>`_ docs page.
+
+################################################################################
+# Example 2: Manually Optimizing Matrix Multiplication with TE
+# ------------------------------------------------------------
+#
+# Now we will consider a second, more advanced example, demonstrating how,
+# with just 18 lines of Python code, TVM can achieve up to an 18x speedup on
+# a common matrix multiplication operation.
+#
+# **There are two important optimizations for compute-intensive applications
+# executed on a CPU:**
+#
+# 1. Increase the cache hit rate of memory access. Both complex numerical
+#    computation and hot-spot memory access can be accelerated by a high cache
+#    hit rate. This requires us to transform the original memory access pattern
+#    into one that fits the cache policy.
+# 2. SIMD (Single Instruction, Multiple Data), also known as vector processing.
+#    Each instruction processes a small batch of data rather than a single
+#    element. This requires us to transform the data access pattern in the loop
+#    body into a uniform pattern so that the LLVM backend can lower it to SIMD.
+#
+# The techniques used in this tutorial are a subset of tricks mentioned in this
+# `repository <https://github.com/flame/how-to-optimize-gemm>`_. Some of them
+# are applied automatically by TVM's abstractions, but some of them cannot
+# be simply applied due to TVM constraints.
+#
+# All the experiment results mentioned below were obtained on a 2015 15" MacBook
+# equipped with an Intel i7-4770HQ CPU. The cache line size should be 64 bytes
+# for all x86 CPUs.
+
+################################################################################
+# Preparation and Performance Baseline
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# We begin by collecting performance data on the ``numpy`` implementation of
+# matrix multiplication.
+
+import tvm
+import tvm.testing
+from tvm import te
+import numpy
+import timeit
+
+# The size of the matrix
+# (M, K) x (K, N)
+# You are free to try out different shapes; sometimes TVM optimization outperforms numpy with MKL.
+M = 1024
+K = 1024
+N = 1024
+
+# The default tensor data type in tvm
+dtype = "float32"
+
+# Using Intel AVX2 (Advanced Vector Extensions) ISA for SIMD.
+# To get the best performance, please change the following line
+# to "llvm -mcpu=core-avx2", or to the specific type of CPU you use.

Review comment:
       Let users know how they can find what cpu they have.
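One possible answer, as a sketch assuming a Linux host: the CPU model and its instruction-set flags are listed in `/proc/cpuinfo` (and summarized by `lscpu`). The hypothetical helper below parses that text and suggests `llvm -mcpu=core-avx2` when the `avx2` flag is present, falling back to plain `llvm` otherwise. It is deliberately coarse; a more specific `-mcpu` value such as `skylake-avx512` may suit newer CPUs better.

```python
import platform


def suggest_llvm_target(cpuinfo_text):
    """Return (cpu model name, suggested TVM target string) from /proc/cpuinfo text."""
    model = None
    flags = set()
    for line in cpuinfo_text.splitlines():
        key, _, value = line.partition(":")
        key = key.strip()
        # Only the first processor entry is needed; later entries repeat it.
        if key == "model name" and model is None:
            model = value.strip()
        elif key == "flags" and not flags:
            flags = set(value.split())
    # core-avx2 is an LLVM CPU name for Haswell-class processors with AVX2.
    target = "llvm -mcpu=core-avx2" if "avx2" in flags else "llvm"
    return model, target


if platform.system() == "Linux":
    with open("/proc/cpuinfo") as f:
        print(suggest_llvm_target(f.read()))
```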

##########
File path: tutorials/get_started/tune_matmul_x86.py
##########
@@ -15,24 +15,27 @@
 # specific language governing permissions and limitations
 # under the License.
 """
-Auto-scheduling Matrix Multiplication for CPU
-=============================================
+Optimizing Operators with Auto-scheduling
+=========================================
 **Author**: `Lianmin Zheng <https://github.com/merrymercy>`_, \
             `Chengfan Jia <https://github.com/jcf94/>`_
 
-This is a tutorial on how to use the auto-scheduler for CPUs.
+In this tutorial, we will show how TVM's Auto Scheduling feature can find
+optimal schedules, without the need for writing a custom template.

Review comment:
       ```suggestion
   optimal schedules without the need for writing a custom template.
   ```
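As a toy analogy for "finding a schedule without writing a template", the sketch below measures a few candidate tile sizes for a blocked NumPy matrix multiplication and keeps the fastest. This is not TVM's algorithm: the auto-scheduler builds its search space automatically and ranks candidates with a learned cost model. The candidate sizes and the names `blocked_matmul` and `search_block_size` are illustrative assumptions.

```python
import timeit

import numpy
import numpy.testing


def blocked_matmul(A, B, bn):
    """Compute A @ B tile by tile with square bn x bn tiles."""
    M, K = A.shape
    N = B.shape[1]
    C = numpy.zeros((M, N), dtype=A.dtype)
    for xo in range(0, M, bn):
        for yo in range(0, N, bn):
            for ko in range(0, K, bn):
                # Slicing handles edge tiles when a dimension is not a multiple of bn.
                C[xo:xo + bn, yo:yo + bn] += A[xo:xo + bn, ko:ko + bn] @ B[ko:ko + bn, yo:yo + bn]
    return C


def search_block_size(A, B, candidates=(16, 32, 64, 128), number=3):
    """Time each candidate tile size and return (fastest size, timings)."""
    expected = A @ B
    timings = {}
    for bn in candidates:
        # Check that the candidate is correct before measuring it.
        numpy.testing.assert_allclose(blocked_matmul(A, B, bn), expected, rtol=1e-4)
        timings[bn] = timeit.timeit(lambda: blocked_matmul(A, B, bn), number=number)
    best = min(timings, key=timings.get)
    return best, timings


a = numpy.random.rand(128, 128).astype("float32")
b = numpy.random.rand(128, 128).astype("float32")
best, timings = search_block_size(a, b)
print("fastest tile size:", best)
```

The winner depends on the machine and on NumPy's BLAS, which is exactly why measuring candidates beats guessing.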

##########
File path: tutorials/get_started/tvmc_command_line_driver.py
##########
@@ -15,30 +15,30 @@
 # specific language governing permissions and limitations
 # under the License.
 """
-Getting Started with TVM command line driver - TVMC
-===================================================
+Compiling and Optimizing a Model with TVMC
+==========================================
 **Authors**:
 `Leandro Nunes <https://github.com/leandron>`_,
-`Matthew Barrett <https://github.com/mbaret>`_
+`Matthew Barrett <https://github.com/mbaret>`_,
+`Chris Hoge <https://github.com/hogepodge>`_
 
-This tutorial is an introduction to working with TVMC, the TVM command
-line driver. TVMC is a tool that exposes TVM features such as
-auto-tuning, compiling, profiling and execution of models, via a
-command line interface.
+In this section, we will work with TVMC, the TVM command line driver. TVMC is a tool that exposes TVM
+features such as auto-tuning, compiling, profiling and execution of models, all through a command line

Review comment:
       ```suggestion
   features such as auto-tuning, compiling, profiling and execution of models through a command line
   ```

##########
File path: tutorials/get_started/tune_matmul_x86.py
##########
@@ -67,12 +72,17 @@ def matmul_add(N, L, M, dtype):
     return [A, B, C, out]
 
 
-######################################################################
+################################################################################
 # Create the search task
-# ^^^^^^^^^^^^^^^^^^^^^^
-# We then create a search task with N=L=M=1024 and dtype="float32"
-# If your machine supports avx instructions, you can
+# ----------------------
+# With the function defined, we can now create the task for the auth_scheduler

Review comment:
       ```suggestion
   # With the function defined, we can now create the task for the auto_scheduler
   ```
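When checking the tuned kernel, a NumPy reference for the workload is handy. A sketch, assuming `matmul_add` computes `out = A @ B + C` with A of shape (N, L), B of shape (L, M) and C of shape (N, M), as in the upstream tutorial; `matmul_add_ref` is a hypothetical name.

```python
import numpy


def matmul_add_ref(N, L, M, dtype="float32", seed=0):
    """Random inputs plus the expected output for out = A @ B + C."""
    rng = numpy.random.default_rng(seed)
    a = rng.random((N, L)).astype(dtype)
    b = rng.random((L, M)).astype(dtype)
    c = rng.random((N, M)).astype(dtype)
    out = a @ b + c
    return a, b, c, out


a_np, b_np, c_np, out_np = matmul_add_ref(128, 64, 32)
```

After running the compiled function on `a_np`, `b_np` and `c_np`, its output can be compared with `out_np` using `numpy.testing.assert_allclose`.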

##########
File path: tutorials/get_started/tvmc_command_line_driver.py
##########
@@ -97,114 +97,134 @@
 
 
 ######################################################################
-# Compiling the model
-# -------------------
+# Compiling an ONNX Model to the TVM Runtime
+# ------------------------------------------
 #
-# The next step once we've downloaded ResNet-50, is to compile it,
-# To accomplish that, we are going to use ``tvmc compile``. The
-# output we get from the compilation process is a TAR package,
-# that can be used to run our model on the target device.
+# Once we've downloaded the ResNet-50 model, the next step is to compile it. To accomplish that, we are
+# going to use ``tvmc compile``. The output we get from the compilation process is a TAR package of the model
+# compiled to a dynamic library for our target platform. We can run that model on our target device using the
+# TVM runtime.
 #
 # .. code-block:: bash
 #
 #   tvmc compile \
-#     --target "llvm" \
-#     --output compiled_module.tar \
-#     resnet50-v2-7.onnx
+#     --target "llvm" \
+#     --output resnet50-v2-7-tvm.tar \
+#     resnet50-v2-7.onnx
+#
+# Let's take a look at the files that ``tvmc compile`` creates:
+#
+# .. code-block:: bash
 #
-# Once compilation finishes, the output ``compiled_module.tar`` will be created. This
-# can be directly loaded by your application and run via the TVM runtime APIs.
+#   mkdir model
+#   tar -xvf resnet50-v2-7-tvm.tar -C model
+#   ls model
+#
+# You will see three files listed.
+#
+# * ``mod.so`` is the model, represented as a C++ library, that can be loaded by the TVM runtime.
+# * ``mod.json`` is a text representation of the TVM Relay computation graph.
+# * ``mod.params`` is a file containing the parameters for the pre-trained model.
+#
+# This model can be directly loaded by your application and run via the TVM runtime APIs.
 #
 
 
 ######################################################################
-# .. note:: Defining the correct target
+# .. note:: Defining the Correct Target
 #
 #   Specifying the correct target (option ``--target``) can have a huge
 #   impact on the performance of the compiled module, as it can take
 #   advantage of hardware features available on the target. For more
 #   information, please refer to `Auto-tuning a convolutional network
 #   for x86 CPU <https://tvm.apache.org/docs/tutorials/autotvm/tune_relay_x86.html#define-network>`_.
+#   We recommend identifying which CPU you are running, along with optional features,
+#   and setting the target appropriately.
 #
 
-
 ######################################################################
+# Running the TVM IR Model with TVMC
+# ----------------------------------
+#
+# Now that we've compiled the model, we can use the TVM runtime to make predictions with it.
+# TVMC has the TVM runtime build in to it, allowing you to run compiled TVM models. To use TVMC to run the

Review comment:
       ```suggestion
   # TVMC has the TVM runtime built in to it, allowing you to run compiled TVM models. To use TVMC to run the
   ```
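The `mkdir`/`tar`/`ls` steps in the hunk above can also be done from Python with the standard `tarfile` module. A sketch: `inspect_package` is a hypothetical helper, and the expected members (`mod.so`, `mod.json`, `mod.params`) come from the tutorial text rather than from checking a real `tvmc` build.

```python
import os
import tarfile


def inspect_package(tar_path, extract_dir="model"):
    """Extract a tvmc package and return the sorted names of its regular files."""
    os.makedirs(extract_dir, exist_ok=True)
    with tarfile.open(tar_path) as tar:
        # normpath turns entries such as "./mod.so" into "mod.so".
        names = sorted(os.path.normpath(m.name) for m in tar.getmembers() if m.isfile())
        if hasattr(tarfile, "data_filter"):
            # Safer extraction on Python versions that support filters.
            tar.extractall(extract_dir, filter="data")
        else:
            tar.extractall(extract_dir)
    return names


if os.path.exists("resnet50-v2-7-tvm.tar"):
    print(inspect_package("resnet50-v2-7-tvm.tar"))
```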

##########
File path: tutorials/get_started/tvmc_command_line_driver.py
##########
@@ -15,30 +15,30 @@
 # specific language governing permissions and limitations
 # under the License.
 """
-Getting Started with TVM command line driver - TVMC
-===================================================
+Compiling and Optimizing a Model with TVMC
+==========================================
 **Authors**:
 `Leandro Nunes <https://github.com/leandron>`_,
-`Matthew Barrett <https://github.com/mbaret>`_
+`Matthew Barrett <https://github.com/mbaret>`_,
+`Chris Hoge <https://github.com/hogepodge>`_
 
-This tutorial is an introduction to working with TVMC, the TVM command
-line driver. TVMC is a tool that exposes TVM features such as
-auto-tuning, compiling, profiling and execution of models, via a
-command line interface.
+In this section, we will work with TVMC, the TVM command line driver. TVMC is a tool that exposes TVM
+features such as auto-tuning, compiling, profiling and execution of models, all through a command line
+interface.
 
-In this tutorial we are going to use TVMC to compile, run and tune a
-ResNet-50 on a x86 CPU.
+Upon completion of this section, we will have used TVMC to accomplish the following tasks:
 
-We are going to start by downloading ResNet 50 V2. Then, we are going
-to use TVMC to compile this model into a TVM module, and use the
-compiled module to generate predictions. Finally, we are going to experiment
-with the auto-tuning options, that can be used to help the compiler to
-improve network performance.
+* Compile a pre-trained ResNet 50 v2 model for the TVM runtime.
+* Run a real image through the compiled model, and interpret the output and model performance.
+* Tune the model that model on a CPU using TVM.

Review comment:
       ```suggestion
   * Tune the model on a CPU using TVM.
   ```



