You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/05/19 23:51:35 UTC

[GitHub] [tvm] trevor-m commented on a change in pull request #8074: [Frontend] [Tensorflow2] Added test infrastructure for TF2 frozen models

trevor-m commented on a change in pull request #8074:
URL: https://github.com/apache/tvm/pull/8074#discussion_r635654606



##########
File path: tests/python/frontend/tensorflow2/test_functional_models.py
##########
@@ -0,0 +1,368 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except
+# pylint: disable=import-outside-toplevel, redefined-builtin
+"""TF2 to relay converter test: tests basic examples"""
+
+import tempfile
+import tensorflow as tf
+import numpy as np
+import pytest
+from common import compare_tf_tvm
+from common import run_tf_code
+
+
+class AddOne(tf.Module):

Review comment:
       For all the tests in this file, I think it would be best to define the `tf.Module` class(es) inside the test function that uses them. For example:
   
   ```
   def test_expand_dims():
       class ExpandDims(tf.Module):
           def get_input(self):
               return np.ones((1, 30), dtype=np.float32)
   
           @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)])
           def func(self, x):
               return tf.expand_dims(x, axis=2)
   
       run_all(ExpandDims)
   ```

##########
File path: tests/python/frontend/tensorflow2/common.py
##########
@@ -0,0 +1,120 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except
+# pylint: disable=import-outside-toplevel, redefined-builtin
+"""TF2 to relay converter test utilities"""
+
+import tvm
+from tvm import relay
+
+from tvm.runtime.vm import VirtualMachine
+import tvm.contrib.graph_executor as runtime
+from tvm.relay.frontend.tensorflow import from_tensorflow
+import tvm.testing
+
+import tensorflow as tf
+from tensorflow.python.eager.def_function import Function
+
+
+def vmobj_to_list(o):
+    if isinstance(o, tvm.nd.NDArray):
+        out = o.asnumpy().tolist()
+    elif isinstance(o, tvm.runtime.container.ADT):
+        result = []
+        for f in o:
+            result.append(vmobj_to_list(f))
+        out = result
+    else:
+        raise RuntimeError("Unknown object type: %s" % type(o))
+    return out
+
+
+def run_tf_code(func, input_):
+    if type(func) is Function:
+        out = func(input_)
+        if isinstance(out, list):
+            a = [x.numpy() for x in out]
+        else:
+            a = out.numpy()
+    else:
+        a = func(tf.constant(input_))
+        if type(a) is dict:
+            a = [x.numpy() for x in a.values()]
+            if len(a) == 1:
+                a = a[0]
+        elif type(a) is list:
+            a = [x.numpy() for x in a]
+            if len(a) == 1:
+                a = a[0]
+        else:
+            a = a.numpy()
+    return a
+
+
+def compile_graph_runtime(

Review comment:
       Graph Runtime is now called Graph Executor, so you may want to rename these functions.

##########
File path: tests/python/frontend/tensorflow2/test_sequential_models.py
##########
@@ -0,0 +1,118 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except
+# pylint: disable=import-outside-toplevel, redefined-builtin
+"""TF2 to relay converter test: testing models built with tf.keras.Sequential()"""
+
+import tempfile
+import numpy as np
+import pytest
+import tensorflow as tf
+from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
+
+from common import compare_tf_tvm
+from common import run_tf_code
+
+
+def run_sequential_model(model_fn, input_shape):
+    def get_input(shape):
+        _input = np.random.uniform(0, 1, shape).astype(dtype="float32")
+        return _input
+
+    def save_and_reload(_model):
+        with tempfile.TemporaryDirectory() as model_path:
+            tf.saved_model.save(_model, model_path)
+            loaded = tf.saved_model.load(model_path)
+            func = loaded.signatures["serving_default"]
+            frozen_func = convert_variables_to_constants_v2(func)
+        return frozen_func
+
+    def model_graph(model, input_shape):
+        _input = get_input(input_shape)
+        f = save_and_reload(model(input_shape))
+        _output = run_tf_code(f, _input)
+        gdef = f.graph.as_graph_def(add_shapes=True)
+        return gdef, _input, _output
+
+    compare_tf_tvm(*model_graph(model_fn, input_shape), vm=True, output_sig=None)
+
+
+def dense_model(input_shape, num_units=128):

Review comment:
       Same here: I would put these functions inside the test functions in which they are used.

##########
File path: tests/python/frontend/tensorflow2/common.py
##########
@@ -0,0 +1,120 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except
+# pylint: disable=import-outside-toplevel, redefined-builtin
+"""TF2 to relay converter test utilities"""
+
+import tvm
+from tvm import relay
+
+from tvm.runtime.vm import VirtualMachine
+import tvm.contrib.graph_executor as runtime
+from tvm.relay.frontend.tensorflow import from_tensorflow
+import tvm.testing
+
+import tensorflow as tf
+from tensorflow.python.eager.def_function import Function
+
+
+def vmobj_to_list(o):
+    if isinstance(o, tvm.nd.NDArray):
+        out = o.asnumpy().tolist()
+    elif isinstance(o, tvm.runtime.container.ADT):
+        result = []
+        for f in o:
+            result.append(vmobj_to_list(f))
+        out = result
+    else:
+        raise RuntimeError("Unknown object type: %s" % type(o))
+    return out
+
+
+def run_tf_code(func, input_):
+    if type(func) is Function:
+        out = func(input_)
+        if isinstance(out, list):
+            a = [x.numpy() for x in out]
+        else:
+            a = out.numpy()
+    else:
+        a = func(tf.constant(input_))
+        if type(a) is dict:
+            a = [x.numpy() for x in a.values()]
+            if len(a) == 1:
+                a = a[0]
+        elif type(a) is list:
+            a = [x.numpy() for x in a]
+            if len(a) == 1:
+                a = a[0]
+        else:
+            a = a.numpy()
+    return a
+
+
+def compile_graph_runtime(
+    mod, params, target="llvm", target_host="llvm", opt_level=3, output_sig=None
+):
+    with tvm.transform.PassContext(opt_level):
+        lib = relay.build(mod, target=target, target_host=target_host, params=params)
+    return lib
+
+
+def compile_vm(
+    mod, params, target="llvm", target_host="llvm", opt_level=3, disabled_pass=None, output_sig=None
+):
+    with tvm.transform.PassContext(opt_level, disabled_pass=disabled_pass):
+        mod = relay.transform.InferType()(mod)

Review comment:
       I don't think InferType is required here

##########
File path: tests/python/frontend/tensorflow2/common.py
##########
@@ -0,0 +1,120 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except
+# pylint: disable=import-outside-toplevel, redefined-builtin
+"""TF2 to relay converter test utilities"""
+
+import tvm
+from tvm import relay
+
+from tvm.runtime.vm import VirtualMachine
+import tvm.contrib.graph_executor as runtime
+from tvm.relay.frontend.tensorflow import from_tensorflow
+import tvm.testing
+
+import tensorflow as tf
+from tensorflow.python.eager.def_function import Function
+
+
+def vmobj_to_list(o):
+    if isinstance(o, tvm.nd.NDArray):
+        out = o.asnumpy().tolist()
+    elif isinstance(o, tvm.runtime.container.ADT):
+        result = []
+        for f in o:
+            result.append(vmobj_to_list(f))
+        out = result
+    else:
+        raise RuntimeError("Unknown object type: %s" % type(o))
+    return out
+
+
+def run_tf_code(func, input_):
+    if type(func) is Function:
+        out = func(input_)
+        if isinstance(out, list):
+            a = [x.numpy() for x in out]
+        else:
+            a = out.numpy()
+    else:
+        a = func(tf.constant(input_))
+        if type(a) is dict:
+            a = [x.numpy() for x in a.values()]
+            if len(a) == 1:
+                a = a[0]
+        elif type(a) is list:
+            a = [x.numpy() for x in a]
+            if len(a) == 1:
+                a = a[0]
+        else:
+            a = a.numpy()
+    return a
+
+
+def compile_graph_runtime(
+    mod, params, target="llvm", target_host="llvm", opt_level=3, output_sig=None

Review comment:
       `output_sig` is unused here and in `compile_vm()`. Is this intended? What is `output_sig`?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org