Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/07/01 21:49:07 UTC

[GitHub] [incubator-tvm] jwfromm opened a new pull request #5975: [Tutorial] Demo showing how to run a pruned 🤗 model.

jwfromm opened a new pull request #5975:
URL: https://github.com/apache/incubator-tvm/pull/5975


   This tutorial demonstrates how to load and run a sparse model from the popular transformers module from [Hugging Face](https://huggingface.co/) (🤗). Very recently, a 95% sparse version of BERT was made publicly available; however, 🤗 was unable to achieve speedups using existing frameworks. Using this script, TVM enables a 2-3X speedup by converting appropriate dense layers to sparse dense layers. I think this will be a useful tutorial for users interested in sparse networks and may be good PR for TVM as a small collaboration with 🤗.


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] jwfromm commented on pull request #5975: [Tutorial] Demo showing how to run a pruned 🤗 model.

Posted by GitBox <gi...@apache.org>.
jwfromm commented on pull request #5975:
URL: https://github.com/apache/incubator-tvm/pull/5975#issuecomment-653053847


   @merrymercy it's fairly quick, I commented out the run command due to dependencies rather than the runtime. This tutorial requires tensorflow 2.2 (our servers currently use 2.1) and transformers. If we think it's worth updating the server build then we can run this for real.





[GitHub] [incubator-tvm] Wheest commented on pull request #5975: [Tutorial] Demo showing how to run a pruned 🤗 model.

Posted by GitBox <gi...@apache.org>.
Wheest commented on pull request #5975:
URL: https://github.com/apache/incubator-tvm/pull/5975#issuecomment-656220438


   Note that the `.ipynb` version of the tutorial doesn't work when running `benchmark()`, since it uses the `__file__` variable in `import_graphdef()`, which is not defined in most notebook environments. An alternative approach to getting the path may be needed.
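
   One possible workaround, sketched under the assumption that falling back to the current working directory is acceptable in a notebook. The helper name `get_script_dir` is hypothetical and is not part of the tutorial:

```python
import os


def get_script_dir():
    """Return the directory of this script, or the working directory in notebooks."""
    try:
        # __file__ is defined when the code runs as a regular Python script.
        return os.path.dirname(os.path.abspath(__file__))
    except NameError:
        # Most notebook environments do not define __file__, so fall back to
        # the current working directory (an assumption, not the tutorial's code).
        return os.getcwd()
```

   `import_graphdef()` could then use `abs_path = get_script_dir()` in place of the direct `__file__` lookup.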
   
   There are also a lot of dependencies for the tutorial (e.g. transformers, tensorflow) which may not be in a user's environment.  Should an Install dependencies section be added [à la](https://tvm.apache.org/docs/tutorials/autotvm/tune_relay_mobile_gpu.html#install-dependencies)?
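
   Until such a section exists, a user could check for the packages up front. This is a minimal sketch using only the standard library; the function name `missing_modules` is hypothetical, and the module list is taken from the dependencies mentioned in this thread:

```python
import importlib.util


def missing_modules(names):
    """Return the subset of top-level module names that cannot be found for import."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Top-level module names for the dependencies discussed in this thread.
required = ["tensorflow", "transformers", "scipy"]
missing = missing_modules(required)
if missing:
    print("Missing dependencies: %s" % ", ".join(missing))
```

   This only checks that the modules can be located; it does not verify versions such as the tensorflow 2.2 requirement.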





[GitHub] [incubator-tvm] tqchen commented on pull request #5975: [Tutorial] Demo showing how to run a pruned 🤗 model.

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #5975:
URL: https://github.com/apache/incubator-tvm/pull/5975#issuecomment-652672949


   cc @antinucleon @junrushao1994 





[GitHub] [incubator-tvm] binarybana commented on a change in pull request #5975: [Tutorial] Demo showing how to run a pruned 🤗 model.

Posted by GitBox <gi...@apache.org>.
binarybana commented on a change in pull request #5975:
URL: https://github.com/apache/incubator-tvm/pull/5975#discussion_r448687396



##########
File path: tutorials/frontend/deploy_sparse.py
##########
@@ -0,0 +1,326 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Deploy a Hugging Face Pruned Model on CPU
+=========================================
+**Author**: `Josh Fromm <https://github.com/jwfromm>`_
+
+This tutorial demonstrates how to take a state of the art pruned model,
+in this case PruneBert from Hugging Face, and use TVM to leverage the model's
+sparsity to produce real speedups. Although the primary purpose of this
+tutorial is to show speedups on already pruned models, it may be useful
+to estimate how fast a model would be *if* it were pruned. To this end,
+we also provide a function that takes an unpruned model and replaces its weights
+with random and pruned weights at a specified sparsity. This may be a useful
+feature when trying to decide if a model is worth pruning or not.
+
+Before we get into the code, it's useful to discuss sparsity and pruning
+and dig into the two different types of sparsity: **structured** and **unstructured**.
+
+Pruning is a technique primarily used to reduce the parameter size of a model
+by replacing weights with 0s. Although many methods exist for choosing which
+weights should be set to 0, the most straightforward is to pick the
+weights with the smallest magnitude. Typically, weights are pruned to a desired
+sparsity percentage. For example, a 95% sparse model would have only 5% of
+its weights non-zero. Pruning to very high sparsities often requires
+finetuning or full retraining as it tends to be a lossy approximation.
+Although parameter size benefits are quite easy to obtain from a pruned model
+through simple compression, leveraging sparsity to yield runtime speedups
+is more complicated.
+
+In structured sparsity, weights are pruned with the goal of clustering
+pruned weights together. In other words, they are pruned using both their
+value and location. The benefit of bunching up pruned weights is that it allows
+an algorithm such as matrix multiplication to skip entire blocks. It turns out
+that some degree of *block sparsity* is very important to realizing significant
+speedups. This is because when loading memory in most CPUs or GPUs, it's not
+possible to load a single value; instead an entire chunk or tile is read in and
+processed using something like vectorized instructions.
+
+Unstructured sparse weights are those that are pruned only on the value of
+the original weights. They may appear to be scattered randomly throughout
+a tensor rather than in chunks like we'd see in block sparse weights.
+At low sparsities, unstructured pruning techniques are difficult to
+accelerate. However, at high sparsities many blocks of all 0 values
+will naturally appear, making it possible to accelerate.
+
+This tutorial interacts with both structured and unstructured sparsity.
+Hugging Face's PruneBert model is unstructured but 95% sparse, allowing us
+to apply TVM's block sparse optimizations to it, even if not optimally.
+When generating random sparse weights for an unpruned model, we do so with
+structured sparsity. A fun exercise is comparing the real speed of PruneBert with
+the block sparse speed using fake weights to see the benefit of structured sparsity.
+"""
+
+###############################################################################
+# Load Required Modules
+# ---------------------
+# Other than TVM, scipy, the latest transformers, and
+# tensorflow 2.2+ are required.
+import os
+import tvm
+import time
+import itertools
+import numpy as np
+import tensorflow as tf
+from tvm import relay
+from tvm.contrib import graph_runtime
+from tvm.relay import data_dep_optimization as ddo
+from tensorflow.python.framework.convert_to_constants import (
+    convert_variables_to_constants_v2,
+)
+import scipy.sparse as sp
+
+
+###############################################################################
+# Configure Settings
+# ------------------
+# Let's start by defining some parameters that define the type of model
+# and sparsity to run.
+# Args:
+# name (str):
+#   The name of the transformer model to download and run.
+# batch_size (int):
+#   The number of sequences in each input batch.
+# seq_len (int):
+#   The length of each input sequence.
+# target (str):
+#   TVM platform identifier. Although cuda is also supported, it requires
+#   tuning that is outside the scope of this tutorial. Note that best
+#   cpu performance can be achieved by setting -mcpu appropriately for
+#   your specific machine.
+# ctx (context):
+#   Which device to run on. Should be one of tvm.cpu() or tvm.gpu().
+# measure_sparse (bool):
+#   If true, then a sparse variant of the network will be run and
+#   benchmarked.
+# bs_r (int):
+#   The block size of structured sparsity to convert weight tensors
+#   into. Changing this parameter may yield speedups for some platforms.
+# sparsity (float):
+#   For models besides PruneBert (which is 95% sparse), this parameter
+#   determines how sparse the generated weights should be. The higher
+#   the sparsity, the faster the result.
+name = "huggingface/prunebert-base-uncased-6-finepruned-w-distil-squad"
+batch_size = 1
+seq_len = 128
+target = "llvm"
+ctx = tvm.cpu()
+measure_sparse = True
+bs_r = 1
+sparsity = 0.85
+
+
+###############################################################################
+# Download and Convert Transformers Model
+# ---------------------------------------
+# Now we'll grab a model from the transformers module, download it,
+# convert it into a graphdef, and finally convert that graphdef into
+# a relay graph that we can optimize and deploy.
+def load_keras_model(module, name, seq_len, batch_size, report_runtime=True):
+    model = module.from_pretrained(name)
+    dummy_input = tf.keras.Input(shape=[seq_len], batch_size=batch_size, dtype="int32")
+    dummy_out = model(dummy_input)  # Propagate shapes through the keras model.
+    if report_runtime:
+        np_input = np.random.uniform(
+            size=[batch_size, seq_len], low=0, high=seq_len
+        ).astype("int32")
+        start = time.time()
+        repeats = 50
+        for i in range(repeats):
+            np_out = model(np_input)
+        end = time.time()
+        print("Keras Runtime: %f ms." % (1000 * ((end - start) / repeats)))
+    return model
+
+
+def convert_to_graphdef(model, batch_size, seq_len):
+    model_func = tf.function(lambda x: model(x))
+    input_dict = model._saved_model_inputs_spec
+    input_spec = input_dict[list(input_dict.keys())[0]]
+    model_func = model_func.get_concrete_function(
+        tf.TensorSpec([batch_size, seq_len], input_spec.dtype)
+    )
+    frozen_func = convert_variables_to_constants_v2(model_func)
+    return frozen_func.graph.as_graph_def()
+
+
+def download_model(name, batch_size, seq_len):
+    import transformers
+    module = getattr(transformers, "TFBertForSequenceClassification")
+    model = load_keras_model(module, name=name, batch_size=batch_size, seq_len=seq_len)
+    return convert_to_graphdef(model, batch_size, seq_len)
+
+
+###############################################################################
+# Convert to Relay Graph
+# ----------------------
+# We now have all the tooling to get a transformers model in the right format
+# for relay conversion. Let's import it! In the following function we
+# save the imported graph in relay's json format so that we don't have
+# to reimport from tensorflow each time this script is run.
+def import_graphdef(
+    name,
+    batch_size,
+    seq_len,
+    save_relay=True,
+    relay_file="model.json",
+    relay_params="model.params",
+):
+    abs_path = os.path.dirname(os.path.abspath(__file__))
+    shape_dict = {"input_1": (batch_size, seq_len)}
+    relay_file = ("%s_%d_%d_%s" % (name, batch_size, seq_len, relay_file)).replace(
+        "/", "_"
+    )
+    relay_params = ("%s_%d_%d_%s" % (name, batch_size, seq_len, relay_params)).replace(
+        "/", "_"
+    )
+    if os.path.exists(os.path.join(abs_path, relay_file)) and os.path.exists(
+        os.path.join(abs_path, relay_params)
+    ):
+        with open(os.path.join(abs_path, relay_file), "r") as fi:
+            mod = tvm.ir.load_json(fi.read())
+        with open(os.path.join(abs_path, relay_params), "rb") as fi:
+            params = relay.load_param_dict(fi.read())
+    else:
+        graph_def = download_model(name, batch_size, seq_len)
+
+        mod, params = relay.frontend.from_tensorflow(graph_def, shape=shape_dict)
+
+        if save_relay:
+            with open(os.path.join(abs_path, relay_file), "w") as fo:
+                fo.write(tvm.ir.save_json(mod))
+            with open(os.path.join(abs_path, relay_params), "wb") as fo:
+                fo.write(relay.save_param_dict(params))
+
+    return mod, params, shape_dict
+
+
+###############################################################################
+# Run the Dense Graph
+# -------------------
+# Let's run the default version of the imported model. Note that even if
+# the weights are sparse, we won't see any speedup because we are using
+# regular matrix multiplies instead of fast sparse kernels.
+def run_relay_graph(mod, params, shape_dict, target, ctx):
+    with relay.build_config(opt_level=3):
+        graph, lib, params = relay.build(mod, target=target, params=params)
+    input_shape = shape_dict["input_1"]
+    dummy_data = np.random.uniform(size=input_shape, low=0, high=input_shape[1]).astype(
+        "int32"
+    )
+
+    m = graph_runtime.create(graph, lib, ctx)
+    m.set_input(0, dummy_data)
+    m.set_input(**params)
+    m.run()
+    tvm_output = m.get_output(0)
+
+    ftimer = m.module.time_evaluator("run", ctx, repeat=5, number=5)
+    prof_res = np.array(ftimer().results) * 1000
+    print(
+        "%-20s %-19s (%s)"
+        % ("Runtime:", "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res))
+    )
+    return tvm_output
+
+
+def run_dense(mod, params, shape_dict, target, ctx):
+    print("Dense Model Benchmark:")
+    return run_relay_graph(mod, params, shape_dict, target, ctx)
+
+
+###############################################################################
+# Run the Sparse Graph
+# --------------------
+# Next we'll convert the graph into a sparse representation and generate
+# fake sparse weights if needed. Then we'll use the same benchmarking
+# script as dense to see how much faster we go! We apply a few relay passes
+# to the graph to get it leveraging sparsity. First we use
+# `simplify_fc_transpose` to fold transposes on the weights of dense layers
+# into the parameters. This makes it easier to convert matrix multiplies
+# to sparse versions. Next we apply `bsr_dense.convert` to identify all

Review comment:
       Also, maybe explain how it relates to `tvm.relay.transform.DenseToSparse`?
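
   As background for the block-sparse conversion that this part of the tutorial describes, here is a rough NumPy sketch of measuring what fraction of fixed-size blocks in a weight matrix are entirely zero. It is not TVM code and does not reflect how `bsr_dense.convert` or `DenseToSparse` is implemented; the function name `zero_block_fraction` is hypothetical.

```python
import numpy as np


def zero_block_fraction(weights, block_rows, block_cols):
    """Fraction of (block_rows x block_cols) tiles in a 2-D array that are all zero."""
    rows, cols = weights.shape
    assert rows % block_rows == 0 and cols % block_cols == 0
    # Split the matrix into tiles: axes 1 and 3 index positions inside a tile.
    tiles = weights.reshape(rows // block_rows, block_rows, cols // block_cols, block_cols)
    # A tile can be skipped only if every value inside it is zero.
    tile_is_zero = ~tiles.astype(bool).any(axis=(1, 3))
    return tile_is_zero.mean()


# Tiny example: a 4x4 matrix with a single non-zero entry.
w = np.zeros((4, 4))
w[0, 0] = 1.0
print(zero_block_fraction(w, 1, 1))  # element-wise sparsity
print(zero_block_fraction(w, 2, 2))  # sparsity at 2x2 block granularity
```

   Comparing the two numbers for a real weight matrix gives a feel for how much of its unstructured sparsity survives at a given block size.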







[GitHub] [incubator-tvm] jwfromm commented on pull request #5975: [Tutorial] Demo showing how to run a pruned 🤗 model.

Posted by GitBox <gi...@apache.org>.
jwfromm commented on pull request #5975:
URL: https://github.com/apache/incubator-tvm/pull/5975#issuecomment-652663993


   @masahi, @vinx13, @binarybana can you take a look and let me know what you think?





[GitHub] [incubator-tvm] tqchen commented on pull request #5975: [Tutorial] Demo showing how to run a pruned 🤗 model.

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #5975:
URL: https://github.com/apache/incubator-tvm/pull/5975#issuecomment-653225192


   Thanks everyone, this is merged, will ping the thread again once we have TF 2.2 landed in the CI





[GitHub] [incubator-tvm] merrymercy commented on pull request #5975: [Tutorial] Demo showing how to run a pruned 🤗 model.

Posted by GitBox <gi...@apache.org>.
merrymercy commented on pull request #5975:
URL: https://github.com/apache/incubator-tvm/pull/5975#issuecomment-652786167


   How long does this tutorial take to run?
   If it takes a lot of time, it is better to provide some sample outputs so users know what to expect.
   If it does not take a lot of time, it is better to let it run on the CI server, so we can make sure the tutorial is always runnable.
   





[GitHub] [incubator-tvm] u99127 commented on pull request #5975: [Tutorial] Demo showing how to run a pruned 🤗 model.

Posted by GitBox <gi...@apache.org>.
u99127 commented on pull request #5975:
URL: https://github.com/apache/incubator-tvm/pull/5975#issuecomment-653161883


   > @merrymercy it's fairly quick, I commented out the run command due to dependencies rather than the run time. This tutorial requires tensorflow 2.2 (our servers currently use 2.1) and transformers. If we think it's worth updating the server build then we can run this for real.
   
   A +1 for updating and keeping this running out of the box. 





[GitHub] [incubator-tvm] merrymercy edited a comment on pull request #5975: [Tutorial] Demo showing how to run a pruned 🤗 model.

Posted by GitBox <gi...@apache.org>.
merrymercy edited a comment on pull request #5975:
URL: https://github.com/apache/incubator-tvm/pull/5975#issuecomment-652786167


   How long does this tutorial take to run?
   If it takes a lot of time, it is better to provide some sample outputs so users know what to expect.
   If it does not take a lot of time, it is better to let it run on the CI server, so we can make sure the tutorial is always runnable.
   





[GitHub] [incubator-tvm] masahi commented on pull request #5975: [Tutorial] Demo showing how to run a pruned 🤗 model.

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #5975:
URL: https://github.com/apache/incubator-tvm/pull/5975#issuecomment-652686170


   I liked emojis in the PR:)
   
   How about adding a sample output, with avx2 or 512?
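
   For reference, targeting those instruction sets would mean changing the tutorial's `target` string. A minimal sketch, assuming the usual LLVM CPU names `core-avx2` and `skylake-avx512`; the helper `select_target` and its labels are hypothetical and not part of the tutorial:

```python
# Assumed LLVM CPU names; pick the one matching the machine being benchmarked.
TARGETS = {
    "generic": "llvm",
    "avx2": "llvm -mcpu=core-avx2",
    "avx512": "llvm -mcpu=skylake-avx512",
}


def select_target(isa="generic"):
    """Return a TVM target string for the given instruction-set label."""
    return TARGETS[isa]


# Replaces the tutorial's default of target = "llvm".
target = select_target("avx2")
```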





[GitHub] [incubator-tvm] binarybana commented on a change in pull request #5975: [Tutorial] Demo showing how to run a pruned 🤗 model.

Posted by GitBox <gi...@apache.org>.
binarybana commented on a change in pull request #5975:
URL: https://github.com/apache/incubator-tvm/pull/5975#discussion_r448652687



##########
File path: tutorials/frontend/deploy_sparse.py
##########
@@ -0,0 +1,326 @@
+"""
+Deploy a Hugging Face Pruned Model on CPU
+=========================================
+**Author**: `Josh Fromm <https://github.com/jwfromm>`_
+
+This tutorial demonstrates how to take a state of the art pruned model,
+in this case PruneBert from Hugging Face, and use TVM to leverage the model's

Review comment:
       Can you link to the model here?

##########
File path: tutorials/frontend/deploy_sparse.py
##########
@@ -0,0 +1,326 @@
+"""
+Deploy a Hugging Face Pruned Model on CPU
+=========================================
+**Author**: `Josh Fromm <https://github.com/jwfromm>`_
+
+This tutorial demonstrates how to take a state of the art pruned model,
+in this case PruneBert from Hugging Face, and use TVM to leverage the model's
+sparsity to produce real speedups. Although the primary purpose of this

Review comment:
       ```suggestion
   sparsity support to produce real speedups. Although the primary purpose of this
   ```

##########
File path: tutorials/frontend/deploy_sparse.py
##########
@@ -0,0 +1,326 @@
+###############################################################################
+# Download and Convert Transformers Model
+# ---------------------------------------
+# Now we'll grab a model from the transformers module, download it,
+# convert it into a graphdef, and finally convert that graphdef into

Review comment:
       ```suggestion
   # convert it into a TensorFlow graphdef, in preparation for converting that graphdef into
   ```

##########
File path: tutorials/frontend/deploy_sparse.py
##########
@@ -0,0 +1,326 @@
+"""
+Deploy a Hugging Face Pruned Model on CPU
+=========================================
+**Author**: `Josh Fromm <https://github.com/jwfromm>`_
+
+This tutorial demonstrates how to take a state of the art pruned model,
+in this case PruneBert from Hugging Face, and use TVM to leverage the model's
+sparsity to produce real speedups. Although the primary purpose of this
+tutorial is to show speedups on already pruned models, it may be useful
+to estimate how fast a model would be *if* it were pruned. To this end,
+we also provide a function that takes an unpruned model and replaces its weights
+with random and pruned weights at a specified sparsity. This may be a useful
+feature when trying to decide if a model is worth pruning or not.
+
+Before we get into the code, it's useful to discuss sparsity and pruning
+and dig into the two
+different types of sparsity: **structured** and **unstructured**.
+
+Pruning is a technique primarily used to reduce the parameter size of a model
+by replacing weights with 0s. Although many methods exist for choosing which

Review comment:
       ```suggestion
   by replacing weight values with 0s. Although many methods exist for choosing which
   ```
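
For readers following the review, the pruning idea in this paragraph can be sketched in plain Python. This is only an illustration under stated assumptions: `magnitude_prune` is a hypothetical helper, not part of the tutorial or of TVM, and it works on a flat list of weights rather than on real tensors.

```python
def magnitude_prune(weights, sparsity):
    """Return a copy of `weights` with the smallest-magnitude entries set to 0.0.

    Hypothetical helper for illustration only; `sparsity` is the fraction of
    entries to zero out, so 0.95 keeps only 5% of the weights.
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError("sparsity must be between 0 and 1")
    num_to_prune = int(round(sparsity * len(weights)))
    # Indices ordered by absolute value, smallest first.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = list(weights)
    for i in order[:num_to_prune]:
        pruned[i] = 0.0
    return pruned


weights = [0.9, -0.05, 0.3, 0.01, -0.7, 0.2, -0.02, 0.6, 0.04, -0.1]
pruned = magnitude_prune(weights, sparsity=0.5)
print(pruned)
```

Real pruning pipelines usually follow this step with finetuning, as the docstring notes; the sketch only covers the selection of which weights become 0.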

##########
File path: tutorials/frontend/deploy_sparse.py
##########
@@ -0,0 +1,326 @@
+"""
+Deploy a Hugging Face Pruned Model on CPU
+=========================================
+**Author**: `Josh Fromm <https://github.com/jwfromm>`_
+
+This tutorial demonstrates how to take a state of the art pruned model,
+in this case PruneBert from Hugging Face, and use TVM to leverage the model's
+sparsity to produce real speedups. Although the primary purpose of this
+tutorial is to show speedups on already pruned models, it may be useful
+to estimate how fast a model would be *if* it were pruned. To this end,
+we also provide a function that takes an unpruned model and replaces its weights
+with random and pruned weights at a specified sparsity. This may be a useful
+feature when trying to decide if a model is worth pruning or not.
+
+Before we get into the code, it's useful to discuss sparsity and pruning
+and dig into the two
+different types of sparsity: **structured** and **unstructured**.
+
+Pruning is a technique primarily used to reduce the parameter size of a model
+by replacing weights with 0s. Although many methods exist for choosing which
+weights should be set to 0, the most straightforward is to pick the
+weights with the smallest magnitude. Typically, weights are pruned to a desired
+sparsity percentage. For example, a 95% sparse model would have only 5% of
+its weights non-zero. Pruning to very high sparsities often requires
+finetuning or full retraining as it tends to be a lossy approximation.
+Although parameter size benefits are quite easy to obtain from a pruned model
+through simple compression, leveraging sparsity to yield runtime speedups
+is more complicated.
+
+In structured sparsity weights are pruned with the goal of clustering
+pruned weights together. In other words, they are pruned using both their
+value and location. The benefit of bunching up pruned weights is that it allows
+an algorithm such as matrix multiplication to skip entire blocks. It turns out
+that some degree of *block sparsity* is very important to realizing significant
+speedups. This is because when loading memory in most CPUs or GPUs, it's not
+possible to load a single value, instead an entire chunk or tile is read in and
+executed using something like vectorized instructions.
+
+Unstructured sparse weights are those that are pruned only on the value of
+the original weights. They may appear to be scattered randomly throughout
+a tensor rather than in chunks like we'd see in block sparse weights.
+At low sparsities, unstructured pruning techniques are difficult to
+accelerate. However, at high sparsities many blocks of all 0 values
+will naturally appear, making it possible to accelerate.
+
+This tutorial interacts with both structured and unstructured sparsity.
+Hugging Face's PruneBert model is unstructured but 95% sparse, allowing us
+to apply TVM's block sparse optimizations to it, even if not optimally.
+When generating random sparse weights for an unpruned model, we do so structured
+sparsity. A fun exercise is comparing the real speed of PruneBert with the block
+sparse speed using fake weights to see the benefit of structured sparsity.
+"""
+
+###############################################################################
+# Load Required Modules
+# ---------------------
+# Other than TVM, scipy, the latest transformers, and
+# tensorflow 2.2+ are required.
+import os
+import tvm
+import time
+import itertools
+import numpy as np
+import tensorflow as tf
+from tvm import relay
+from tvm.contrib import graph_runtime
+from tvm.relay import data_dep_optimization as ddo
+from tensorflow.python.framework.convert_to_constants import (
+    convert_variables_to_constants_v2,
+)
+import scipy.sparse as sp
+
+
+###############################################################################
+# Configure Settings
+# ------------------
+# Let's start by defining some parameters that define the type of model
+# and sparsity to run.
+# Args:
+# name (str):
+#   The name of the transformer model to download and run.
+# batch_size (int):
+#   The number of input sequences in each batch.
+# seq_len (int):
+#   The length of each input sequence.
+# target (str):
+#   TVM platform identifier. Although cuda is also supported, it requires
+#   tuning that is outside the scope of this tutorial. Note that best
+#   cpu performance can be achieved by setting -mcpu appropriately for
+#   your specific machine.
+# ctx (context):
+#   Which device to run on. Should be one of tvm.cpu() or tvm.gpu().
+# measure_sparse (bool):
+#   If true, then a sparse variant of the network will be run and
+#   benchmarked.
+# bs_r (int):
+#   The block size of structured sparsity to convert weight tensors
+#   into. Changing this parameter may yield speedups for some platforms.
+# sparsity (float):
+#   For models besides PruneBert (which is 95% sparse), this parameter
+#   determines how sparse the generated weights should be. The higher
+#   the sparsity, the faster the result.
+name = "huggingface/prunebert-base-uncased-6-finepruned-w-distil-squad"
+batch_size = 1
+seq_len = 128
+target = "llvm"
+ctx = tvm.cpu()
+measure_sparse = True
+bs_r = 1
+sparsity = 0.85
+
+
+###############################################################################
+# Download and Convert Transformers Model
+# ---------------------------------------
+# Now we'll grab a model from the transformers module, download it,
+# convert it into a graphdef, and finally convert that graphdef into
+# a relay graph that we can optimize and deploy.
+def load_keras_model(module, name, seq_len, batch_size, report_runtime=True):
+    model = module.from_pretrained(name)
+    dummy_input = tf.keras.Input(shape=[seq_len], batch_size=batch_size, dtype="int32")
+    dummy_out = model(dummy_input)  # Propagate shapes through the keras model.
+    if report_runtime:
+        np_input = np.random.uniform(
+            size=[batch_size, seq_len], low=0, high=seq_len
+        ).astype("int32")
+        start = time.time()
+        repeats = 50
+        for i in range(repeats):
+            np_out = model(np_input)
+        end = time.time()
+        print("Keras Runtime: %f ms." % (1000 * ((end - start) / repeats)))
+    return model
+
+
+def convert_to_graphdef(model, batch_size, seq_len):
+    model_func = tf.function(lambda x: model(x))
+    input_dict = model._saved_model_inputs_spec
+    input_spec = input_dict[list(input_dict.keys())[0]]
+    model_func = model_func.get_concrete_function(
+        tf.TensorSpec([batch_size, seq_len], input_spec.dtype)
+    )
+    frozen_func = convert_variables_to_constants_v2(model_func)
+    return frozen_func.graph.as_graph_def()
+
+
+def download_model(name, batch_size, seq_len):
+    import transformers
+    module = getattr(transformers, "TFBertForSequenceClassification")
+    model = load_keras_model(module, name=name, batch_size=batch_size, seq_len=seq_len)
+    return convert_to_graphdef(model, batch_size, seq_len)
+
+
+###############################################################################
+# Convert to Relay Graph
+# ----------------------
+# We now have all the tooling to get a transformers model in the right format
+# for relay conversion. Let's import it! In the following function we
+# save the imported graph in relay's json format so that we don't have
+# to reimport from tensorflow each time this script is run.
+def import_graphdef(
+    name,
+    batch_size,
+    seq_len,
+    save_relay=True,
+    relay_file="model.json",
+    relay_params="model.params",
+):
+    abs_path = os.path.dirname(os.path.abspath(__file__))
+    shape_dict = {"input_1": (batch_size, seq_len)}
+    relay_file = ("%s_%d_%d_%s" % (name, batch_size, seq_len, relay_file)).replace(
+        "/", "_"
+    )
+    relay_params = ("%s_%d_%d_%s" % (name, batch_size, seq_len, relay_params)).replace(
+        "/", "_"
+    )
+    if os.path.exists(os.path.join(abs_path, relay_file)) and os.path.exists(
+        os.path.join(abs_path, relay_params)
+    ):
+        with open(os.path.join(abs_path, relay_file), "r") as fi:
+            mod = tvm.ir.load_json(fi.read())
+        with open(os.path.join(abs_path, relay_params), "rb") as fi:
+            params = relay.load_param_dict(fi.read())
+    else:
+        graph_def = download_model(name, batch_size, seq_len)
+
+        mod, params = relay.frontend.from_tensorflow(graph_def, shape=shape_dict)
+
+        if save_relay:
+            with open(os.path.join(abs_path, relay_file), "w") as fo:
+                fo.write(tvm.ir.save_json(mod))
+            with open(os.path.join(abs_path, relay_params), "wb") as fo:
+                fo.write(relay.save_param_dict(params))
+
+    return mod, params, shape_dict
+
+
+###############################################################################
+# Run the Dense Graph
+# -------------------
+# Let's run the default version of the imported model. Note that even if
+# the weights are sparse, we won't see any speedup because we are using
+# regular matrix multiplies instead of fast sparse kernels.

Review comment:
       ```suggestion
   # regular dense matrix multiplications on these dense (but mostly zero) tensors instead of sparse-aware kernels.
   ```
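
To make the distinction in this suggestion concrete, here is a small plain-Python sketch comparing a dense matrix-vector product with one over a compressed sparse row (CSR) layout. The function names are hypothetical and this is not how TVM implements its sparse kernels; it only shows that the dense loop does work for every zero while the sparse-aware loop touches stored non-zeros only.

```python
def dense_matvec(matrix, vec):
    """Multiply every entry, zeros included. Returns (result, multiply_count)."""
    result, mults = [], 0
    for row in matrix:
        acc = 0.0
        for w, x in zip(row, vec):
            acc += w * x
            mults += 1
        result.append(acc)
    return result, mults


def to_csr(matrix):
    """Convert a dense row-major matrix to CSR arrays (data, indices, indptr)."""
    data, indices, indptr = [], [], [0]
    for row in matrix:
        for col, w in enumerate(row):
            if w != 0:
                data.append(w)
                indices.append(col)
        indptr.append(len(data))
    return data, indices, indptr


def csr_matvec(data, indices, indptr, vec):
    """Multiply only the stored non-zeros. Returns (result, multiply_count)."""
    result, mults = [], 0
    for row in range(len(indptr) - 1):
        acc = 0.0
        for k in range(indptr[row], indptr[row + 1]):
            acc += data[k] * vec[indices[k]]
            mults += 1
        result.append(acc)
    return result, mults


matrix = [
    [0.0, 2.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0, 3.0],
    [0.0, 0.0, 0.0, 0.0],
]
vec = [1.0, 2.0, 3.0, 4.0]

dense_out, dense_mults = dense_matvec(matrix, vec)
sparse_out, sparse_mults = csr_matvec(*to_csr(matrix), vec)
print(dense_out, dense_mults)
print(sparse_out, sparse_mults)
```

Both paths give the same result here, but the dense path performs 16 multiplications and the CSR path performs 3, which is the work a sparse-aware kernel can skip.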

##########
File path: tutorials/frontend/deploy_sparse.py
##########
@@ -0,0 +1,326 @@
+"""
+Deploy a Hugging Face Pruned Model on CPU
+=========================================
+**Author**: `Josh Fromm <https://github.com/jwfromm>`_
+
+This tutorial demonstrates how to take a state of the art pruned model,
+in this case PruneBert from Hugging Face, and use TVM to leverage the model's
+sparsity to produce real speedups. Although the primary purpose of this
+tutorial is to show speedups on already pruned models, it may be useful

Review comment:
       ```suggestion
   tutorial is to realize speedups on already pruned models, it may also be useful
   ```

##########
File path: tutorials/frontend/deploy_sparse.py
##########
@@ -0,0 +1,326 @@
+"""
+Deploy a Hugging Face Pruned Model on CPU
+=========================================
+**Author**: `Josh Fromm <https://github.com/jwfromm>`_
+
+This tutorial demonstrates how to take a state of the art pruned model,

Review comment:
       ```suggestion
   This tutorial demonstrates how to take any pruned model,
   ```

##########
File path: tutorials/frontend/deploy_sparse.py
##########
@@ -0,0 +1,326 @@
+"""
+Deploy a Hugging Face Pruned Model on CPU
+=========================================
+**Author**: `Josh Fromm <https://github.com/jwfromm>`_
+
+This tutorial demonstrates how to take a state of the art pruned model,
+in this case PruneBert from Hugging Face, and use TVM to leverage the model's
+sparsity to produce real speedups. Although the primary purpose of this
+tutorial is to show speedups on already pruned models, it may be useful
+to estimate how fast a model would be *if* it were pruned. To this end,
+we also provide a function that takes an unpruned model and replaces its weights
+with random and pruned weights at a specified sparsity. This may be a useful
+feature when trying to decide if a model is worth pruning or not.
+
+Before we get into the code, it's useful to discuss sparsity and pruning
+and dig into the two
+different types of sparsity: **structured** and **unstructured**.
+
+Pruning is a technique primarily used to reduce the parameter size of a model
+by replacing weights with 0s. Although many methods exist for choosing which
+weights should be set to 0, the most straightforward is to pick the
+weights with the smallest magnitude. Typically, weights are pruned to a desired
+sparsity percentage. For example, a 95% sparse model would have only 5% of
+its weights non-zero. Pruning to very high sparsities often requires
+finetuning or full retraining as it tends to be a lossy approximation.
+Although parameter size benefits are quite easy to obtain from a pruned model
+through simple compression, leveraging sparsity to yield runtime speedups
+is more complicated.
+
+In structured sparsity weights are pruned with the goal of clustering
+pruned weights together. In other words, they are pruned using both their
+value and location. The benefit of bunching up pruned weights is that it allows
+an algorithm such as matrix multiplication to skip entire blocks. It turns out
+that some degree of *block sparsity* is very important to realizing significant
+speedups. This is because when loading memory in most CPUs or GPUs, it's not

Review comment:
       ```suggestion
   speedups on most hardware available today. This is because when loading memory in most CPUs or GPUs, it doesn't
   ```

##########
File path: tutorials/frontend/deploy_sparse.py
##########
@@ -0,0 +1,326 @@
+"""
+Deploy a Hugging Face Pruned Model on CPU
+=========================================
+**Author**: `Josh Fromm <https://github.com/jwfromm>`_
+
+This tutorial demonstrates how to take a state of the art pruned model,
+in this case PruneBert from Hugging Face, and use TVM to leverage the model's
+sparsity to produce real speedups. Although the primary purpose of this
+tutorial is to show speedups on already pruned models, it may be useful
+to estimate how fast a model would be *if* it were pruned. To this end,
+we also provide a function that takes an unpruned model and replaces its weights
+with random and pruned weights at a specified sparsity. This may be a useful
+feature when trying to decide if a model is worth pruning or not.
+
+Before we get into the code, it's useful to discuss sparsity and pruning
+and dig into the two
+different types of sparsity: **structured** and **unstructured**.
+
+Pruning is a technique primarily used to reduce the parameter size of a model
+by replacing weights with 0s. Although many methods exist for choosing which
+weights should be set to 0, the most straightforward is to pick the
+weights with the smallest magnitude. Typically, weights are pruned to a desired
+sparsity percentage. For example, a 95% sparse model would have only 5% of
+its weights non-zero. Pruning to very high sparsities often requires
+finetuning or full retraining as it tends to be a lossy approximation.
+Although parameter size benefits are quite easy to obtain from a pruned model
+through simple compression, leveraging sparsity to yield runtime speedups
+is more complicated.
+
+In structured sparsity weights are pruned with the goal of clustering
+pruned weights together. In other words, they are pruned using both their
+value and location. The benefit of bunching up pruned weights is that it allows
+an algorithm such as matrix multiplication to skip entire blocks. It turns out
+that some degree of *block sparsity* is very important to realizing significant
+speedups. This is because when loading memory in most CPUs or GPUs, it's not
+possible to load a single value, instead an entire chunk or tile is read in and

Review comment:
       ```suggestion
   save any work to skip reading a single value at a time; instead, an entire chunk or tile is read in and
   ```
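
The point about skipping whole blocks can be checked with a short plain-Python sketch. `count_nonzero_blocks` is a hypothetical helper written for this note, not a TVM or scipy function. It compares two 4x4 matrices with the same 75% sparsity, one with its non-zeros clustered and one with them scattered.

```python
def count_nonzero_blocks(matrix, block_rows, block_cols):
    """Return (nonzero_blocks, total_blocks) for a dense row-major matrix.

    A block counts as non-zero if any entry inside it is non-zero, because a
    block-sparse kernel would then have to load and compute the whole block.
    """
    rows, cols = len(matrix), len(matrix[0])
    if rows % block_rows or cols % block_cols:
        raise ValueError("matrix shape must be divisible by the block shape")
    nonzero = 0
    total = 0
    for r0 in range(0, rows, block_rows):
        for c0 in range(0, cols, block_cols):
            total += 1
            if any(
                matrix[r][c] != 0
                for r in range(r0, r0 + block_rows)
                for c in range(c0, c0 + block_cols)
            ):
                nonzero += 1
    return nonzero, total


# Same number of non-zero weights (4 of 16), different layouts.
clustered = [
    [1, 2, 0, 0],
    [3, 4, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
]
scattered = [
    [1, 0, 0, 0],
    [0, 0, 2, 0],
    [0, 3, 0, 0],
    [0, 0, 0, 4],
]

clustered_blocks = count_nonzero_blocks(clustered, 2, 2)
scattered_blocks = count_nonzero_blocks(scattered, 2, 2)
print(clustered_blocks, scattered_blocks)
```

With 2x2 blocks the clustered layout leaves three of four blocks skippable, while the scattered layout leaves none, even though both matrices are equally sparse.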

##########
File path: tutorials/frontend/deploy_sparse.py
##########
@@ -0,0 +1,326 @@
+"""
+Deploy a Hugging Face Pruned Model on CPU
+=========================================
+**Author**: `Josh Fromm <https://github.com/jwfromm>`_
+
+This tutorial demonstrates how to take a state of the art pruned model,
+in this case PruneBert from Hugging Face, and use TVM to leverage the model's
+sparsity to produce real speedups. Although the primary purpose of this
+tutorial is to show speedups on already pruned models, it may be useful
+to estimate how fast a model would be *if* it were pruned. To this end,
+we also provide a function that takes an unpruned model and replaces its weights
+with random and pruned weights at a specified sparsity. This may be a useful
+feature when trying to decide if a model is worth pruning or not.
+
+Before we get into the code, it's useful to discuss sparsity and pruning
+and dig into the two
+different types of sparsity: **structured** and **unstructured**.
+
+Pruning is a technique primarily used to reduce the parameter size of a model
+by replacing weights with 0s. Although many methods exist for choosing which
+weights should be set to 0, the most straightforward is to pick the
+weights with the smallest magnitude. Typically, weights are pruned to a desired
+sparsity percentage. For example, a 95% sparse model would have only 5% of
+its weights non-zero. Pruning to very high sparsities often requires
+finetuning or full retraining as it tends to be a lossy approximation.
+Although parameter size benefits are quite easy to obtain from a pruned model
+through simple compression, leveraging sparsity to yield runtime speedups
+is more complicated.
+
+In structured sparsity weights are pruned with the goal of clustering
+pruned weights together. In other words, they are pruned using both their
+value and location. The benefit of bunching up pruned weights is that it allows
+an algorithm such as matrix multiplication to skip entire blocks. It turns out
+that some degree of *block sparsity* is very important to realizing significant
+speedups. This is because when loading memory in most CPUs or GPUs, it's not
+possible to load a single value, instead an entire chunk or tile is read in and
+executed using something like vectorized instructions.
+
+Unstructured sparse weights are those that are pruned only on the value of
+the original weights. They may appear to be scattered randomly throughout
+a tensor rather than in chunks like we'd see in block sparse weights.
+At low sparsities, unstructured pruning techniques are difficult to
+accelerate. However, at high sparsities many blocks of all 0 values
+will naturally appear, making it possible to accelerate.
+
+This tutorial interacts with both structured and unstructured sparsity.
+Hugging Face's PruneBert model is unstructured but 95% sparse, allowing us
+to apply TVM's block sparse optimizations to it, even if not optimally.
+When generating random sparse weights for an unpruned model, we do so structured

Review comment:
       ```suggestion
   When generating random sparse weights for an unpruned model, we do so with structured
   ```
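
As a rough illustration of what "random sparse weights with structured sparsity" could look like, the sketch below fills whole blocks or leaves them entirely zero. It is an assumption-laden stand-in, not the tutorial's own helper: `random_block_sparse` and its parameters are hypothetical, and it uses only the Python standard library rather than scipy.

```python
import random


def random_block_sparse(rows, cols, block_rows, block_cols, sparsity, seed=0):
    """Build a dense row-major matrix whose zeros come in whole blocks.

    Hypothetical helper: `sparsity` is the fraction of blocks left at zero.
    Kept blocks are filled with values of magnitude in [0.5, 1.0] so that
    every stored entry is guaranteed to be non-zero.
    """
    if rows % block_rows or cols % block_cols:
        raise ValueError("matrix shape must be divisible by the block shape")
    rng = random.Random(seed)
    blocks_per_row = cols // block_cols
    num_blocks = (rows // block_rows) * blocks_per_row
    num_kept = num_blocks - int(round(sparsity * num_blocks))
    kept = set(rng.sample(range(num_blocks), num_kept))

    matrix = [[0.0] * cols for _ in range(rows)]
    for block in kept:
        r0 = (block // blocks_per_row) * block_rows
        c0 = (block % blocks_per_row) * block_cols
        for r in range(r0, r0 + block_rows):
            for c in range(c0, c0 + block_cols):
                matrix[r][c] = rng.choice((-1.0, 1.0)) * rng.uniform(0.5, 1.0)
    return matrix


weights = random_block_sparse(8, 8, block_rows=2, block_cols=2, sparsity=0.75)
num_nonzero = sum(1 for row in weights for w in row if w != 0.0)
print(num_nonzero, "of", 8 * 8, "weights are non-zero")
```

Setting the block shape to 1x1 reduces this to unstructured random sparsity, which gives a simple way to compare the two regimes at the same overall sparsity.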







[GitHub] [incubator-tvm] u99127 edited a comment on pull request #5975: [Tutorial] Demo showing how to run a pruned 🤗 model.

Posted by GitBox <gi...@apache.org>.
u99127 edited a comment on pull request #5975:
URL: https://github.com/apache/incubator-tvm/pull/5975#issuecomment-653161883


   > @merrymercy it's fairly quick, I commented out the run command due to dependencies rather than the run time. This tutorial requires tensorflow 2.2 (our servers currently use 2.1) and transformers. If we think it's worth updating the server build then we can run this for real.
   
   A +1 for updating TF versions and keeping this running out of the box. 





[GitHub] [incubator-tvm] merrymercy edited a comment on pull request #5975: [Tutorial] Demo showing how to run a pruned 🤗 model.

Posted by GitBox <gi...@apache.org>.
merrymercy edited a comment on pull request #5975:
URL: https://github.com/apache/incubator-tvm/pull/5975#issuecomment-652786167


   How long does this tutorial take to run?
   If it takes a lot of time, it is better to provide some sample outputs, so readers can know what is expected.
   Otherwise, it is better to let it run on the CI server, so we can get output from the web server and make sure the tutorial is always runnable.
   





[GitHub] [incubator-tvm] binarybana commented on a change in pull request #5975: [Tutorial] Demo showing how to run a pruned 🤗 model.

Posted by GitBox <gi...@apache.org>.
binarybana commented on a change in pull request #5975:
URL: https://github.com/apache/incubator-tvm/pull/5975#discussion_r448687234



##########
File path: tutorials/frontend/deploy_sparse.py
##########
@@ -0,0 +1,326 @@
+"""
+Deploy a Hugging Face Pruned Model on CPU
+=========================================
+**Author**: `Josh Fromm <https://github.com/jwfromm>`_
+
+This tutorial demonstrates how to take a state of the art pruned model,
+in this case PruneBert from Hugging Face, and use TVM to leverage the model's
+sparsity to produce real speedups. Although the primary purpose of this
+tutorial is to show speedups on already pruned models, it may be useful
+to estimate how fast a model would be *if* it were pruned. To this end,
+we also provide a function that takes an unpruned model and replaces its weights
+with random and pruned weights at a specified sparsity. This may be a useful
+feature when trying to decide if a model is worth pruning or not.
+
+Before we get into the code, it's useful to discuss sparsity and pruning
+and dig into the two
+different types of sparsity: **structured** and **unstructured**.
+
+Pruning is a technique primarily used to reduce the parameter size of a model
+by replacing weights with 0s. Although many methods exist for choosing which
+weights should be set to 0, the most straightforward is to pick the
+weights with the smallest magnitude. Typically, weights are pruned to a desired
+sparsity percentage. For example, a 95% sparse model would have only 5% of
+its weights non-zero. Pruning to very high sparsities often requires
+finetuning or full retraining as it tends to be a lossy approximation.
+Although parameter size benefits are quite easy to obtain from a pruned model
+through simple compression, leveraging sparsity to yield runtime speedups
+is more complicated.
+
+In structured sparsity weights are pruned with the goal of clustering
+pruned weights together. In other words, they are pruned using both their
+value and location. The benefit of bunching up pruned weights is that it allows
+an algorithm such as matrix multiplication to skip entire blocks. It turns out
+that some degree of *block sparsity* is very important to realizing significant
+speedups. This is because when loading memory in most CPUs or GPUs, it's not
+possible to load a single value, instead an entire chunk or tile is read in and
+executed using something like vectorized instructions.
+
+Unstructured sparse weights are those that are pruned only on the value of
+the original weights. They may appear to be scattered randomly throughout
+a tensor rather than in chunks like we'd see in block sparse weights.
+At low sparsities, unstructured pruning techniques are difficult to
+accelerate. However, at high sparsities many blocks of all 0 values
+will naturally appear, making it possible to accelerate.
+
+This tutorial interacts with both structured and unstructured sparsity.
+Hugging Face's PruneBert model is unstructured but 95% sparse, allowing us
+to apply TVM's block sparse optimizations to it, even if not optimally.
+When generating random sparse weights for an unpruned model, we do so structured
+sparsity. A fun exercise is comparing the real speed of PruneBert with the block
+sparse speed using fake weights to see the benefit of structured sparsity.
+"""
+
+###############################################################################
+# Load Required Modules
+# ---------------------
+# Other than TVM, scipy, the latest transformers, and
+# tensorflow 2.2+ are required.
+import os
+import tvm
+import time
+import itertools
+import numpy as np
+import tensorflow as tf
+from tvm import relay
+from tvm.contrib import graph_runtime
+from tvm.relay import data_dep_optimization as ddo
+from tensorflow.python.framework.convert_to_constants import (
+    convert_variables_to_constants_v2,
+)
+import scipy.sparse as sp
+
+
+###############################################################################
+# Configure Settings
+# ------------------
+# Let's start by defining some parameters that define the type of model
+# and sparsity to run.
+# Args:
+# name (str):
+#   The name of the transformer model to download and run.
+# batch_size (int):
+#   The number of input sequences in each batch.
+# seq_len (int):
+#   The length of each input sequence.
+# target (str):
+#   TVM platform identifier. Although cuda is also supported, it requires
+#   tuning that is outside the scope of this tutorial. Note that best
+#   cpu performance can be achieved by setting -mcpu appropriately for
+#   your specific machine.
+# ctx (context):
+#   Which device to run on. Should be one of tvm.cpu() or tvm.gpu().
+# measure_sparse (bool):
+#   If true, then a sparse variant of the network will be run and
+#   benchmarked.
+# bs_r (int):
+#   The block size of structured sparsity to convert weight tensors
+#   into. Changing this parameter may yield speedups for some platforms.
+# sparsity (float):
+#   For models besides PruneBert (which is 95% sparse), this parameter
+#   determines how sparse the generated weights should be. The higher
+#   the sparsity, the faster the result.
+name = "huggingface/prunebert-base-uncased-6-finepruned-w-distil-squad"
+batch_size = 1
+seq_len = 128
+target = "llvm"
+ctx = tvm.cpu()
+measure_sparse = True
+bs_r = 1
+sparsity = 0.85
+
+
+###############################################################################
+# Download and Convert Transformers Model
+# ---------------------------------------
+# Now we'll grab a model from the transformers module, download it,
+# convert it into a graphdef, and finally convert that graphdef into
+# a relay graph that we can optimize and deploy.
+def load_keras_model(module, name, seq_len, batch_size, report_runtime=True):
+    model = module.from_pretrained(name)
+    dummy_input = tf.keras.Input(shape=[seq_len], batch_size=batch_size, dtype="int32")
+    dummy_out = model(dummy_input)  # Propagate shapes through the keras model.
+    if report_runtime:
+        np_input = np.random.uniform(
+            size=[batch_size, seq_len], low=0, high=seq_len
+        ).astype("int32")
+        start = time.time()
+        repeats = 50
+        for i in range(repeats):
+            np_out = model(np_input)
+        end = time.time()
+        print("Keras Runtime: %f ms." % (1000 * ((end - start) / repeats)))
+    return model
+
+
+def convert_to_graphdef(model, batch_size, seq_len):
+    model_func = tf.function(lambda x: model(x))
+    input_dict = model._saved_model_inputs_spec
+    input_spec = input_dict[list(input_dict.keys())[0]]
+    model_func = model_func.get_concrete_function(
+        tf.TensorSpec([batch_size, seq_len], input_spec.dtype)
+    )
+    frozen_func = convert_variables_to_constants_v2(model_func)
+    return frozen_func.graph.as_graph_def()
+
+
+def download_model(name, batch_size, seq_len):
+    import transformers
+    module = getattr(transformers, "TFBertForSequenceClassification")
+    model = load_keras_model(module, name=name, batch_size=batch_size, seq_len=seq_len)
+    return convert_to_graphdef(model, batch_size, seq_len)
+
+
+###############################################################################
+# Convert to Relay Graph
+# ----------------------
+# We now have all the tooling to get a transformers model in the right format
+# for relay conversion. Let's import it! In the following function we
+# save the imported graph in relay's json format so that we don't have
+# to reimport from tensorflow each time this script is run.
+def import_graphdef(
+    name,
+    batch_size,
+    seq_len,
+    save_relay=True,
+    relay_file="model.json",
+    relay_params="model.params",
+):
+    abs_path = os.path.dirname(os.path.abspath(__file__))
+    shape_dict = {"input_1": (batch_size, seq_len)}
+    relay_file = ("%s_%d_%d_%s" % (name, batch_size, seq_len, relay_file)).replace(
+        "/", "_"
+    )
+    relay_params = ("%s_%d_%d_%s" % (name, batch_size, seq_len, relay_params)).replace(
+        "/", "_"
+    )
+    if os.path.exists(os.path.join(abs_path, relay_file)) and os.path.exists(
+        os.path.join(abs_path, relay_params)
+    ):
+        with open(os.path.join(abs_path, relay_file), "r") as fi:
+            mod = tvm.ir.load_json(fi.read())
+        with open(os.path.join(abs_path, relay_params), "rb") as fi:
+            params = relay.load_param_dict(fi.read())
+    else:
+        graph_def = download_model(name, batch_size, seq_len)
+
+        mod, params = relay.frontend.from_tensorflow(graph_def, shape=shape_dict)
+
+        if save_relay:
+            with open(os.path.join(abs_path, relay_file), "w") as fo:
+                fo.write(tvm.ir.save_json(mod))
+            with open(os.path.join(abs_path, relay_params), "wb") as fo:
+                fo.write(relay.save_param_dict(params))
+
+    return mod, params, shape_dict
+
+
+###############################################################################
+# Run the Dense Graph
+# -------------------
+# Let's run the default version of the imported model. Note that even if
+# the weights are sparse, we won't see any speedup because we are using
+# regular matrix multiplies instead of fast sparse kernels.
+def run_relay_graph(mod, params, shape_dict, target, ctx):
+    with relay.build_config(opt_level=3):
+        graph, lib, params = relay.build(mod, target=target, params=params)
+    input_shape = shape_dict["input_1"]
+    dummy_data = np.random.uniform(size=input_shape, low=0, high=input_shape[1]).astype(
+        "int32"
+    )
+
+    m = graph_runtime.create(graph, lib, ctx)
+    m.set_input(0, dummy_data)
+    m.set_input(**params)
+    m.run()
+    tvm_output = m.get_output(0)
+
+    ftimer = m.module.time_evaluator("run", ctx, repeat=5, number=5)
+    prof_res = np.array(ftimer().results) * 1000
+    print(
+        "%-20s %-19s (%s)"
+        % ("Runtime:", "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res))
+    )
+    return tvm_output
+
+
+def run_dense(mod, params, shape_dict, target, ctx):
+    print("Dense Model Benchmark:")
+    return run_relay_graph(mod, params, shape_dict, target, ctx)
+
+
+###############################################################################
+# Run the Sparse Graph
+# --------------------
+# Next we'll convert the graph into a sparse representation and generate
+# fake sparse weights if needed. Then we'll use the same benchmarking
+# script as dense to see how much faster we go! We apply a few relay passes
+# to the graph to get it leveraging sparsity. First we use
+# `simplify_fc_transpose` to fold transposes on the weights of dense layers
+# into the parameters. This makes it easier to convert matrix multiplies
+# to sparse versions. Next we apply `bsr_dense.convert` to identify all

Review comment:
       Because this is the only mention of `bsr_dense.convert` in TVM documentation today, could you spend a sentence or two describing what it does in more detail?
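
       For context, a minimal sketch of the block sparse row (BSR) layout that the name `bsr_dense` refers to. It uses SciPy's `bsr_matrix` purely to illustrate the storage format and is not TVM's implementation; the helper name `to_bsr` and the toy weight are hypothetical, and it is an assumption that TVM's pass stores weights with the same data/indices/indptr convention.

```python
import numpy as np
import scipy.sparse as sp


def to_bsr(weight, block_shape):
    """Convert a dense 2-D weight matrix to SciPy's BSR format.

    Hypothetical helper for illustration; block_shape must evenly
    divide the shape of `weight`.
    """
    return sp.bsr_matrix(weight, blocksize=block_shape)


# A 4x4 weight whose only non-zero values sit in the top-left 2x2 block.
weight = np.zeros((4, 4), dtype="float32")
weight[:2, :2] = [[1.0, 2.0], [3.0, 4.0]]

bsr = to_bsr(weight, (2, 2))

# Only blocks that contain a non-zero value are stored.
print(bsr.data.shape)  # (1, 2, 2): one stored block of shape 2x2
print(bsr.indices)     # [0]: block-column index of each stored block
print(bsr.indptr)      # [0 1 1]: where each block-row starts in `indices`

# Multiplying with the BSR matrix matches the dense result while only
# touching the stored blocks.
x = np.arange(4, dtype="float32")
assert np.allclose(bsr @ x, weight @ x)
```

       The three arrays are the whole representation: a sparse kernel walks `indptr` and `indices` and multiplies only the blocks held in `data`, which is why all-zero blocks cost nothing.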







[GitHub] [incubator-tvm] jwfromm edited a comment on pull request #5975: [Tutorial] Demo showing how to run a pruned 🤗 model.

Posted by GitBox <gi...@apache.org>.
jwfromm edited a comment on pull request #5975:
URL: https://github.com/apache/incubator-tvm/pull/5975#issuecomment-653053847


   @merrymercy it's fairly quick, I commented out the run command due to dependencies rather than the run time. This tutorial requires tensorflow 2.2 (our servers currently use 2.1) and transformers. If we think it's worth updating the server build then we can run this for real.





[GitHub] [incubator-tvm] tqchen merged pull request #5975: [Tutorial] Demo showing how to run a pruned 🤗 model.

Posted by GitBox <gi...@apache.org>.
tqchen merged pull request #5975:
URL: https://github.com/apache/incubator-tvm/pull/5975


   





[GitHub] [incubator-tvm] jwfromm commented on a change in pull request #5975: [Tutorial] Demo showing how to run a pruned 🤗 model.

Posted by GitBox <gi...@apache.org>.
jwfromm commented on a change in pull request #5975:
URL: https://github.com/apache/incubator-tvm/pull/5975#discussion_r449148996



##########
File path: tutorials/frontend/deploy_sparse.py
##########
@@ -0,0 +1,326 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Deploy a Hugging Face Pruned Model on CPU
+=========================================
+**Author**: `Josh Fromm <https://github.com/jwfromm>`_
+
+This tutorial demonstrates how to take a state of the art pruned model,
+in this case PruneBert from Hugging Face, and use TVM to leverage the model's
+sparsity to produce real speedups. Although the primary purpose of this
+tutorial is to show speedups on already pruned models, it may be useful
+to estimate how fast a model would be *if* it were pruned. To this end,
+we also provide a function that takes an unpruned model and replaces its weights
+with random and pruned weights at a specified sparsity. This may be a useful
+feature when trying to decide if a model is worth pruning or not.
+
+Before we get into the code, it's useful to discuss sparsity and pruning
+and dig into the two different types of sparsity: **structured** and
+**unstructured**.
+
+Pruning is a technique primarily used to reduce the parameter size of a model
+by replacing weights with 0s. Although many methods exist for choosing which
+weights should be set to 0, the most straightforward is to pick the
+weights with the smallest magnitude. Typically, weights are pruned to a desired
+sparsity percentage. For example, a 95% sparse model would have only 5% of
+its weights non-zero. Pruning to very high sparsities often requires
+finetuning or full retraining as it tends to be a lossy approximation.
+Although parameter size benefits are quite easy to obtain from a pruned model
+through simple compression, leveraging sparsity to yield runtime speedups
+is more complicated.
+
+In structured sparsity weights are pruned with the goal of clustering
+pruned weights together. In other words, they are pruned using both their
+value and location. The benefit of bunching up pruned weights is that it allows
+an algorithm such as matrix multiplication to skip entire blocks. It turns out
+that some degree of *block sparsity* is very important to realizing significant
+speedups. This is because when loading memory in most CPUs or GPUs, it's not
+possible to load a single value; instead, an entire chunk or tile is read in
+and processed using something like vectorized instructions.
+
+Unstructured sparse weights are those that are pruned only on the value of
+the original weights. They may appear to be scattered randomly throughout
+a tensor rather than in chunks like we'd see in block sparse weights.
+At low sparsities, unstructured pruning techniques are difficult to
+accelerate. However, at high sparsities many blocks of all 0 values
+will naturally appear, making it possible to accelerate.
+
+This tutorial interacts with both structured and unstructured sparsity.
+Hugging Face's PruneBert model is unstructured but 95% sparse, allowing us
+to apply TVM's block sparse optimizations to it, even if not optimally.
+When generating random sparse weights for an unpruned model, we do so with
+structured sparsity. A fun exercise is comparing the real speed of PruneBert
+with the block sparse speed using fake weights to see the benefit of
+structured sparsity.
+"""
+
+###############################################################################
+# Load Required Modules
+# ---------------------
+# Other than TVM, scipy, the latest transformers, and
+# tensorflow 2.2+ are required.
+import os
+import tvm
+import time
+import itertools
+import numpy as np
+import tensorflow as tf
+from tvm import relay
+from tvm.contrib import graph_runtime
+from tvm.relay import data_dep_optimization as ddo
+from tensorflow.python.framework.convert_to_constants import (
+    convert_variables_to_constants_v2,
+)
+import scipy.sparse as sp
+
+
+###############################################################################
+# Configure Settings
+# ------------------
+# Let's start by defining some parameters that define the type of model
+# and sparsity to run.
+# Args:
+# name (str):
+#   The name of the transformer model to download and run.
+# batch_size (int):
+#   The number of sequences in each input batch.
+# seq_len (int):
+#   The length of each input sequence.
+# target (str):
+#   TVM platform identifier. Although cuda is also supported, it requires
+#   tuning that is outside the scope of this tutorial. Note that best
+#   cpu performance can be achieved by setting -mcpu appropriately for
+#   your specific machine.
+# ctx (context):
+#   Which device to run on. Should be one of tvm.cpu() or tvm.gpu().
+# measure_sparse (bool):
+#   If true, then a sparse variant of the network will be run and
+#   benchmarked.
+# bs_r (int):
+#   The block size of structured sparsity to convert weight tensors
+#   into. Changing this parameter may yield speedups for some platforms.
+# sparsity (float):
+#   For models besides PruneBert (which is 95% sparse), this parameter
+#   determines how sparse the generated weights should be. The higher
+#   the sparsity, the faster the result.
+name = "huggingface/prunebert-base-uncased-6-finepruned-w-distil-squad"
+batch_size = 1
+seq_len = 128
+target = "llvm"
+ctx = tvm.cpu()
+measure_sparse = True
+bs_r = 1
+sparsity = 0.85
+
+
+###############################################################################
+# Download and Convert Transformers Model
+# ---------------------------------------
+# Now we'll grab a model from the transformers module, download it,
+# convert it into a graphdef, and finally convert that graphdef into
+# a relay graph that we can optimize and deploy.
+def load_keras_model(module, name, seq_len, batch_size, report_runtime=True):
+    model = module.from_pretrained(name)
+    dummy_input = tf.keras.Input(shape=[seq_len], batch_size=batch_size, dtype="int32")
+    dummy_out = model(dummy_input)  # Propagate shapes through the keras model.
+    if report_runtime:
+        np_input = np.random.uniform(
+            size=[batch_size, seq_len], low=0, high=seq_len
+        ).astype("int32")
+        start = time.time()
+        repeats = 50
+        for i in range(repeats):
+            np_out = model(np_input)
+        end = time.time()
+        print("Keras Runtime: %f ms." % (1000 * ((end - start) / repeats)))
+    return model
+
+
+def convert_to_graphdef(model, batch_size, seq_len):
+    model_func = tf.function(lambda x: model(x))
+    input_dict = model._saved_model_inputs_spec
+    input_spec = input_dict[list(input_dict.keys())[0]]
+    model_func = model_func.get_concrete_function(
+        tf.TensorSpec([batch_size, seq_len], input_spec.dtype)
+    )
+    frozen_func = convert_variables_to_constants_v2(model_func)
+    return frozen_func.graph.as_graph_def()
+
+
+def download_model(name, batch_size, seq_len):
+    import transformers
+    module = getattr(transformers, "TFBertForSequenceClassification")
+    model = load_keras_model(module, name=name, batch_size=batch_size, seq_len=seq_len)
+    return convert_to_graphdef(model, batch_size, seq_len)
+
+
+###############################################################################
+# Convert to Relay Graph
+# ----------------------
+# We now have all the tooling to get a transformers model in the right format
+# for relay conversion. Let's import it! In the following function we
+# save the imported graph in relay's json format so that we don't have
+# to reimport from tensorflow each time this script is run.
+def import_graphdef(
+    name,
+    batch_size,
+    seq_len,
+    save_relay=True,
+    relay_file="model.json",
+    relay_params="model.params",
+):
+    abs_path = os.path.dirname(os.path.abspath(__file__))
+    shape_dict = {"input_1": (batch_size, seq_len)}
+    relay_file = ("%s_%d_%d_%s" % (name, batch_size, seq_len, relay_file)).replace(
+        "/", "_"
+    )
+    relay_params = ("%s_%d_%d_%s" % (name, batch_size, seq_len, relay_params)).replace(
+        "/", "_"
+    )
+    if os.path.exists(os.path.join(abs_path, relay_file)) and os.path.exists(
+        os.path.join(abs_path, relay_params)
+    ):
+        with open(os.path.join(abs_path, relay_file), "r") as fi:
+            mod = tvm.ir.load_json(fi.read())
+        with open(os.path.join(abs_path, relay_params), "rb") as fi:
+            params = relay.load_param_dict(fi.read())
+    else:
+        graph_def = download_model(name, batch_size, seq_len)
+
+        mod, params = relay.frontend.from_tensorflow(graph_def, shape=shape_dict)
+
+        if save_relay:
+            with open(os.path.join(abs_path, relay_file), "w") as fo:
+                fo.write(tvm.ir.save_json(mod))
+            with open(os.path.join(abs_path, relay_params), "wb") as fo:
+                fo.write(relay.save_param_dict(params))
+
+    return mod, params, shape_dict
+
+
+###############################################################################
+# Run the Dense Graph
+# -------------------
+# Let's run the default version of the imported model. Note that even if
+# the weights are sparse, we won't see any speedup because we are using
+# regular matrix multiplies instead of fast sparse kernels.
+def run_relay_graph(mod, params, shape_dict, target, ctx):
+    with relay.build_config(opt_level=3):
+        graph, lib, params = relay.build(mod, target=target, params=params)
+    input_shape = shape_dict["input_1"]
+    dummy_data = np.random.uniform(size=input_shape, low=0, high=input_shape[1]).astype(
+        "int32"
+    )
+
+    m = graph_runtime.create(graph, lib, ctx)
+    m.set_input(0, dummy_data)
+    m.set_input(**params)
+    m.run()
+    tvm_output = m.get_output(0)
+
+    ftimer = m.module.time_evaluator("run", ctx, repeat=5, number=5)
+    prof_res = np.array(ftimer().results) * 1000
+    print(
+        "%-20s %-19s (%s)"
+        % ("Runtime:", "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res))
+    )
+    return tvm_output
+
+
+def run_dense(mod, params, shape_dict, target, ctx):
+    print("Dense Model Benchmark:")
+    return run_relay_graph(mod, params, shape_dict, target, ctx)
+
+
+###############################################################################
+# Run the Sparse Graph
+# --------------------
+# Next we'll convert the graph into a sparse representation and generate
+# fake sparse weights if needed. Then we'll use the same benchmarking
+# script as dense to see how much faster we go! We apply a few relay passes
+# to the graph to get it leveraging sparsity. First we use
+# `simplify_fc_transpose` to fold transposes on the weights of dense layers
+# into the parameters. This makes it easier to convert matrix multiplies
+# to sparse versions. Next we apply `bsr_dense.convert` to identify all

Review comment:
       Added a paragraph with more details.
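
       The docstring's claim that all-zero blocks appear naturally at high unstructured sparsity can be checked with a short NumPy sketch. This is illustrative only and not part of the tutorial: the helpers `magnitude_prune` and `zero_block_fraction` are hypothetical names, the weights are random rather than taken from PruneBert, and the 1x4 block shape is an arbitrary choice.

```python
import numpy as np


def magnitude_prune(weight, sparsity):
    """Zero the smallest-magnitude entries so roughly `sparsity` of them are 0."""
    threshold = np.quantile(np.abs(weight), sparsity)
    return np.where(np.abs(weight) < threshold, 0.0, weight)


def zero_block_fraction(weight, bs_r, bs_c):
    """Fraction of (bs_r x bs_c) blocks that contain only zeros.

    Assumes the block shape evenly divides the weight shape.
    """
    rows, cols = weight.shape
    # Axis 1 indexes rows inside a block, axis 3 indexes columns inside a block.
    blocks = weight.reshape(rows // bs_r, bs_r, cols // bs_c, bs_c)
    return float(np.mean(~blocks.any(axis=(1, 3))))


rng = np.random.default_rng(0)
dense = rng.standard_normal((768, 768))  # random stand-in for a dense layer

for sparsity in (0.5, 0.85, 0.95):
    pruned = magnitude_prune(dense, sparsity)
    frac = zero_block_fraction(pruned, 1, 4)
    print("sparsity %.2f -> %.3f of 1x4 blocks are all zero" % (sparsity, frac))
```

       For independently distributed weights, the fraction of empty 1x4 blocks is close to the sparsity raised to the fourth power, so it grows quickly as sparsity approaches 1. Structured pruning aims to get a high empty-block fraction at a lower overall sparsity.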



