Posted to commits@tvm.apache.org by zh...@apache.org on 2020/03/08 04:46:21 UTC

[incubator-tvm] branch master updated: Add BN support with run-time mean and variance calculation (#4990)

This is an automated email from the ASF dual-hosted git repository.

zhaowu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new ba47786  Add BN support with run-time mean and variance calculation (#4990)
ba47786 is described below

commit ba477865c0e52778dcb0a07ba84cbe8488cd0719
Author: lfengad <lf...@connect.ust.hk>
AuthorDate: Sun Mar 8 12:46:12 2020 +0800

    Add BN support with run-time mean and variance calculation (#4990)
---
 python/tvm/relay/frontend/tensorflow.py            | 10 ++-
 .../python/frontend/tensorflow/test_bn_dynamic.py  | 72 ++++++++++++++++++++++
 2 files changed, 81 insertions(+), 1 deletion(-)
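
For context: TensorFlow's FusedBatchNorm node carries moving_mean and moving_variance inputs, but in training-style graphs these can arrive as empty tensors. This commit teaches the Relay frontend to detect that case and compute the statistics at run time instead. A minimal NumPy sketch of the per-channel computation for an NHWC input (illustrative only, not part of the commit; the function name and the epsilon default are assumptions):

import numpy as np

def batch_norm_runtime_stats(x, gamma, beta, epsilon=1e-3):
    # Reduce over N, H, W and keep the channel axis (axis 3), so each
    # channel gets its own mean and variance of shape (C,).
    mean = x.mean(axis=(0, 1, 2))
    var = x.var(axis=(0, 1, 2))
    # Broadcasting over the trailing channel axis normalizes per channel.
    return gamma * (x - mean) / np.sqrt(var + epsilon) + beta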

diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 24164a3..f1cb815 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -877,6 +877,7 @@ def _fused_batch_norm():
     def _impl(inputs, attr, params):
         # Tensorflow: (data, gamma, beta, moving_mean, moving_variance)
         # Relay:       (data, gamma, beta, moving_mean, moving_variance)
+        assert len(inputs) == 5
         axis = 3
         need_cast = False
 
@@ -887,7 +888,14 @@ def _fused_batch_norm():
         if 'U' in attr:
             need_cast = True
             inputs[0] = _op.cast(inputs[0], dtype=attr['U'].name)
-
+        # Check whether the moving mean and variance are empty;
+        # if so, replace them with Mean and Variance ops so the
+        # statistics are computed at run time.
+        moving_mean_shape = [int(n) for n in inputs[3].type_annotation.shape]
+        moving_variance_shape = [int(n) for n in inputs[4].type_annotation.shape]
+        if (moving_mean_shape[0] == 0 and moving_variance_shape[0] == 0):
+            inputs[3] = _op.mean(inputs[0], axis=axis, keepdims=False, exclude=True)
+            inputs[4] = _op.variance(inputs[0], axis=axis, keepdims=False, exclude=True)
         out = AttrCvt(op_name='batch_norm',
                       transforms={'scale_after_normalization':'scale',
                                   'variance_epsilon':'epsilon'},
diff --git a/tests/python/frontend/tensorflow/test_bn_dynamic.py b/tests/python/frontend/tensorflow/test_bn_dynamic.py
new file mode 100644
index 0000000..4be838e
--- /dev/null
+++ b/tests/python/frontend/tensorflow/test_bn_dynamic.py
@@ -0,0 +1,72 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Testcases for BatchNorm without given mean and variance
+========================================================
+This test script exercises the fused_batch_norm operator
+in the TensorFlow frontend when mean and variance are not given.
+"""
+import tvm
+import numpy as np
+import tensorflow as tf
+from tvm import relay
+from tensorflow.python.framework import graph_util
+
+def verify_fused_batch_norm(shape):
+    g = tf.Graph()
+    with g.as_default():
+        input_tensor = tf.placeholder(tf.float32, shape=shape, name='input')
+        alpha = tf.constant(np.random.rand(shape[-1],), dtype=tf.float32, name='alpha')
+        beta = tf.constant(np.random.rand(shape[-1],), dtype=tf.float32, name='beta')
+        bn = tf.nn.fused_batch_norm(x=input_tensor, offset=beta, scale=alpha, name='bn')
+        out = tf.identity(bn[0], name='output')
+    data = np.random.rand(*shape)
+    with tf.Session(graph=out.graph) as sess:
+        sess.run([tf.global_variables_initializer()])
+        tf_out = sess.run(out, feed_dict={input_tensor:data})
+        constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['output'])
+
+    for device in ["llvm"]:
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            continue
+        mod, params = relay.frontend.from_tensorflow(constant_graph,
+                                                     outputs=['output'])
+        with relay.build_config(opt_level=3):
+            graph, lib, params = relay.build(mod,
+                                             target=device,
+                                             params=params)
+        from tvm.contrib import graph_runtime
+        m = graph_runtime.create(graph, lib, ctx)
+        m.set_input(**params)
+        m.set_input('input', data)
+        m.run()
+        tvm_out = m.get_output(0)
+        tvm.testing.assert_allclose(tvm_out.asnumpy(), tf_out.astype(tvm_out.dtype),
+                                    atol=1e-3, rtol=1e-3)
+
+def test_fused_batch_norm():
+    verify_fused_batch_norm(shape=(1, 12, 12, 32))
+    verify_fused_batch_norm(shape=(1, 24, 24, 64))
+    verify_fused_batch_norm(shape=(1, 64, 64, 128))
+    verify_fused_batch_norm(shape=(8, 12, 12, 32))
+    verify_fused_batch_norm(shape=(16, 12, 12, 32))
+    verify_fused_batch_norm(shape=(32, 12, 12, 32))
+
+if __name__ == "__main__":
+    test_fused_batch_norm()
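
A usage note (not stated in the commit itself): the test builds its graph with tf.placeholder and tf.Session, so it assumes a TensorFlow 1.x environment. Like the other TVM frontend tests of this era, it can be run standalone with

python tests/python/frontend/tensorflow/test_bn_dynamic.py

which exercises NHWC inputs of several batch sizes and channel counts against the TensorFlow output with atol/rtol of 1e-3.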