You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/10/30 10:38:48 UTC

[GitHub] [incubator-tvm] nolanliou opened a new issue #6801: [BUG] Cannot convert TF model when it contains `argwhere` and `strided_slice` ops

nolanliou opened a new issue #6801:
URL: https://github.com/apache/incubator-tvm/issues/6801


   # Description
   The `strided_slice` op does not work when its input is produced (directly or further down the pipeline) by an `argwhere` op, because the output shape of `argwhere` is dynamic. Is there a better way to fix this?
   
   # Error information
   ```
   TVMError: Check failed: ObjectTypeChecker<TObjectRef>::Check(ptr): Expect Array[IntImm] but get Array
   ```
   # Env
   * TVM: master
   
   # Minimal code to reproduce.
   ```
   import numpy as np
   
   try:
       import tensorflow.compat.v1 as tf
   except ImportError:
       import tensorflow as tf
   from tensorflow.python.ops import variables
   import tvm
   from tvm import relay
   import tvm.relay.testing.tf as tf_testing
   from tvm.runtime.vm import VirtualMachine
   import tvm.testing
   
   def convert_to_list(x):
       """Return ``x`` unchanged if it is a list, otherwise wrap it as ``[x]``."""
       return x if isinstance(x, list) else [x]
   
   def vmobj_to_list(o):
       """Flatten a Relay VM / interpreter result into a flat list.

       - ``tvm.nd.NDArray`` -> ``[numpy array]``
       - ``ADT`` container -> its fields flattened recursively, in order
       - interpreter ``ConstructorValue``:
           * ``Cons`` -> head items followed by tail items
           * ``Nil`` -> ``[]``
           * name containing ``tensor_nil`` -> ``[0]``
           * name containing ``tensor`` -> ``[numpy array of field 0]``

       Raises RuntimeError for any other constructor or object type.
       """
       if isinstance(o, tvm.nd.NDArray):
           return [o.asnumpy()]

       if isinstance(o, tvm.runtime.container.ADT):
           return [item for field in o for item in vmobj_to_list(field)]

       if not isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
           raise RuntimeError("Unknown object type: %s" % type(o))

       tag = o.constructor.name_hint
       if tag == "Cons":
           # Tail is flattened before head, matching the original evaluation order.
           tail_items = vmobj_to_list(o.fields[1])
           head_items = vmobj_to_list(o.fields[0])
           return head_items + tail_items
       if tag == "Nil":
           return []
       if "tensor_nil" in tag:
           return [0]
       if "tensor" in tag:
           return [o.fields[0].asnumpy()]
       raise RuntimeError("Unknown object type: %s" % tag)
   
   
   def run_tvm_graph(
           graph_def,
           input_data,
           input_node,
           num_output=1,
           target="llvm",
           out_names=None,
           opt_level=3,
           mode="graph_runtime",
           cuda_layout="NCHW",
           layout=None,
           disabled_pass=None,
           ignore_in_shape=False,
           serialize=False,
   ):
       """Import a TensorFlow GraphDef into Relay, compile it and run it on TVM.

       Parameters
       ----------
       graph_def : TensorFlow GraphDef to import.
       input_data : input value(s); a single value is wrapped in a list.
       input_node : input node name(s), paired positionally with ``input_data``.
       num_output : number of outputs fetched (graph_runtime mode only).
       target : TVM target string, used in both "vm" and graph_runtime modes.
           When it is "cuda", ``layout`` is overridden with ``cuda_layout``.
       out_names : output node names passed to ``from_tensorflow``.
       opt_level : optimization level for the ``PassContext``.
       mode : "vm" runs on the Relay VM; any other value uses graph_runtime.
       cuda_layout : layout used when ``target == "cuda"``.
       layout : layout passed to ``from_tensorflow`` otherwise.
       disabled_pass : passes disabled in the ``PassContext``.
       ignore_in_shape : if True, no input shapes are passed to the frontend.
       serialize : (vm mode only) round-trip the executable through save/load.

       Returns
       -------
       list
           "vm" mode: the VM result flattened by ``vmobj_to_list``.
           graph_runtime mode: ``num_output`` numpy arrays.
       """
       input_data = convert_to_list(input_data)
       input_node = convert_to_list(input_node)
       if target == "cuda":
           layout = cuda_layout
       target_host = None
       if ignore_in_shape:
           shape_dict = None
       else:
           # Scalars (no .shape) are declared with an empty shape.
           shape_dict = {
               e: i.shape if hasattr(i, "shape") else () for e, i in zip(input_node, input_data)
           }
       mod, params = relay.frontend.from_tensorflow(
           graph_def, layout=layout, shape=shape_dict, outputs=out_names
       )
       ctx = tvm.context(target, 0)
       if mode == "vm":
           with tvm.transform.PassContext(opt_level=opt_level, disabled_pass=disabled_pass):
               # Honour the caller's target so VM mode runs on the same device
               # as graph_runtime mode instead of always compiling for llvm.
               vm_exec = relay.vm.compile(mod, target=target, params=params)
           if serialize:
               code, lib = vm_exec.save()
               vm_exec = tvm.runtime.vm.Executable.load_exec(code, lib)
           vm = VirtualMachine(vm_exec, ctx)
           inputs = {}
           for e, i in zip(input_node, input_data):
               inputs[e] = tvm.nd.array(i)
           result = vm.invoke("main", **inputs)
           return vmobj_to_list(result)
       else:
           with tvm.transform.PassContext(opt_level=opt_level, disabled_pass=disabled_pass):
               graph, lib, params = relay.build(mod, target, target_host, params)
           from tvm.contrib import graph_runtime

           m = graph_runtime.create(graph, lib, ctx)
           # set inputs
           for e, i in zip(input_node, input_data):
               m.set_input(e, tvm.nd.array(i))

           m.set_input(**params)
           # execute
           m.run()
           # get outputs
           assert out_names is None or num_output == len(
               out_names
           ), "out_names: {} num_output: {}".format(out_names, num_output)
           tvm_output_list = [m.get_output(i).asnumpy() for i in range(num_output)]
           return tvm_output_list
   
   
   def run_tf_graph(sess, input_data, input_node, output_node):
       """Run ``sess`` with the given feeds and return the fetched outputs.

       Each of ``input_data``, ``input_node`` and ``output_node`` may be a single
       item or a list. Feeds are paired positionally: ``input_node[k]`` is fed
       ``input_data[k]``.
       """
       feeds = convert_to_list(input_data)
       feed_names = convert_to_list(input_node)
       fetch_names = convert_to_list(output_node)

       fetches = [sess.graph.get_tensor_by_name(name) for name in fetch_names]

       feed_dict = {}
       for pos, name in enumerate(feed_names):
           feed_dict[name] = feeds[pos]

       return sess.run(fetches, feed_dict)
   
   def compare_tf_with_tvm(
           in_data,
           in_name,
           out_name,
           init_global_variables=False,
           no_gpu=False,
           opt_level=3,
           mode="graph_runtime",
           cuda_layout="NCHW",
   ):
       """Run the current default TF graph with TensorFlow and TVM and compare outputs.

       Parameters
       ----------
       in_data : input value(s) to feed.
       in_name : input tensor name(s), e.g. "in_data:0".
       out_name : output tensor name(s), e.g. "output:0".
       init_global_variables : run the TF global-variables initializer first.
       no_gpu : skip the "cuda" device if it is in the device list.
       opt_level, mode, cuda_layout : forwarded to ``run_tvm_graph``.

       Each TF output is checked against the TVM output at the same position
       with ``atol=rtol=1e-5``.
       """

       def name_without_num(name):
           # "node:0" -> "node"; a name without ":" is returned unchanged.
           return name.split(":")[0]

       out_name = convert_to_list(out_name)
       out_node = [name_without_num(name) for name in out_name]

       in_data = convert_to_list(in_data)
       in_name = convert_to_list(in_name)
       in_node = [name_without_num(name) for name in in_name]
       # Leaving the `with` block closes the session; no explicit close() needed.
       with tf.Session() as sess:
           if init_global_variables:
               sess.run(variables.global_variables_initializer())
           final_graph_def = tf_testing.AddShapesToGraphDef(sess, out_node)

           tf_output = run_tf_graph(sess, in_data, in_name, out_name)

           for device in ["llvm"]:
               if not tvm.testing.device_enabled(device):
                   print("Skip because %s is not enabled" % device)
                   continue
               if no_gpu and device == "cuda":
                   continue

               tvm_output = run_tvm_graph(
                   final_graph_def,
                   in_data,
                   in_node,
                   target=device,
                   out_names=out_name,
                   num_output=len(out_name),
                   opt_level=opt_level,
                   mode=mode,
                   cuda_layout=cuda_layout,
               )
               # TF and Relay output names are not exactly the same, so only the
               # first len(tf_output) TVM outputs are compared.
               for i, expected in enumerate(tf_output):
                   tvm.testing.assert_allclose(expected, tvm_output[i], atol=1e-5, rtol=1e-5)
   
   def test_where(
           ip_shape=(6,),
           dtype="int32"):
       """Reproduce issue #6801: where -> gather -> reshape -> strided_slice.

       ``tf.where`` on the input yields indices whose count depends on the data,
       so every tensor derived from it (``gather``, ``reshape``) has a dynamic
       leading dimension when it reaches ``strided_slice``. The graph is
       compared between TensorFlow and TVM using the Relay VM.

       ip_shape : shape of the input placeholder. The first three entries along
           axis 0 are zeroed so that ``where`` selects fewer indices than there
           are elements.
       dtype : dtype of the input placeholder.

       The defaults let pytest collect this ``test_*`` function; arguments
       without defaults would be looked up as (missing) fixtures.
       """

       tf.reset_default_graph()
       with tf.Graph().as_default():
           in_data = tf.placeholder(dtype, ip_shape, name="in_data")
           weight = tf.ones((10, 20))
           mask = tf.squeeze(tf.where(in_data), axis=1)
           data = tf.gather(weight, mask)
           data = tf.reshape(data, shape=(-1, 2, 10))
           tf.strided_slice(
               data,
               begin=(1, 0, 0),
               end=(0, 1, 0),
               strides=(1, 1, 1),
               begin_mask=4,
               end_mask=5,
               new_axis_mask=0,
               shrink_axis_mask=2,
               ellipsis_mask=0,
               name="output",
           )

           np_data = (np.random.uniform(size=ip_shape) * 10).astype(dtype)
           np_data[0:3] = 0
           print(np_data)

           compare_tf_with_tvm(np_data, "in_data:0", "output:0", mode='vm')
   
   
   if __name__ == "__main__":
       # (6,) is a 1-D shape tuple; the original `(6)` was just the int 6.
       test_where((6,), "int32")
   ```


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org