You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by "AndrewZhaoLuo (via GitHub)" <gi...@apache.org> on 2023/01/24 19:48:52 UTC

[GitHub] [tvm] AndrewZhaoLuo commented on a diff in pull request #13802: [ONNX] Support Bernoulli op on ONNX front-end

AndrewZhaoLuo commented on code in PR #13802:
URL: https://github.com/apache/tvm/pull/13802#discussion_r1085839327


##########
tests/python/frontend/onnx/test_forward.py:
##########
@@ -6663,6 +6663,105 @@ def verify_qlinearsigmoid(a_shape):
     verify_qlinearsigmoid([])
 
 
+@tvm.testing.parametrize_targets("llvm")
+def test_random_bernoulli(target, dev):
+    """test_random_bernoulli"""
+
+    def verify_bernoulli_with_ort(
+        shape,
+        in_dtype="float32",
+        out_dtype="int32",
+        seed=None,
+        out_shape=None,
+        target=target,
+        dev=dev,
+        use_vm=False,
+        opset=None,
+        freeze_params=False,
+        rtol=0.1,
+        atol=0.1,
+        opt_level=1,
+        convert_config=None,
+    ):
+        def get_bernoulli_model(shape, in_dtype="float32", out_dtype="int32", seed=None):
+            onnx_itype = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(in_dtype)]
+            onnx_otype = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(out_dtype)]
+            node = helper.make_node(
+                "Bernoulli",
+                ["input"],
+                ["output"],
+            )
+            dtype_attr = helper.make_attribute("dtype", onnx_otype)
+            node.attribute.append(dtype_attr)
+            if seed is not None:
+                seed_attr = helper.make_attribute("seed", seed)
+                node.attribute.append(seed_attr)
+
+            graph = helper.make_graph(
+                [node],
+                "random_bernoulli_test",
+                inputs=[helper.make_tensor_value_info("input", onnx_itype, list(shape))],
+                outputs=[helper.make_tensor_value_info("output", onnx_otype, list(shape))],
+            )
+            return helper.make_model(graph, producer_name="random_bernoulli_test")
+
+        inputs = np.random.uniform(size=shape).astype(in_dtype)
+        if seed is None:
+            ort_seed = None
+        else:
+            ort_seed = float(seed)
+        model = get_bernoulli_model(shape, in_dtype, out_dtype, ort_seed)
+        if opset is not None:
+            model.opset_import[0].version = opset
+
+        ort_out = get_onnxruntime_output(model, inputs)
+        if use_vm:
+            tvm_out = get_tvm_output_with_vm(
+                model,
+                inputs,
+                target,
+                dev,
+                opset=opset,
+                freeze_params=freeze_params,
+                convert_config=convert_config,
+            )
+        else:
+            tvm_out = get_tvm_output(
+                model,
+                inputs,
+                target,
+                dev,
+                out_shape,
+                opset=opset,
+                opt_level=opt_level,
+                convert_config=convert_config,
+            )
+
+        if not isinstance(tvm_out, list):
+            tvm_out = [tvm_out]
+        if not isinstance(ort_out, list):
+            ort_out = [ort_out]
+        for tvm_val, ort_val in zip(tvm_out, ort_out):
+            tvm.testing.assert_allclose(ort_val.mean(), tvm_val.mean(), rtol=rtol, atol=atol)
+            tvm.testing.assert_allclose(np.std(ort_val), np.std(tvm_val), rtol=rtol, atol=atol)

Review Comment:
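   Comparing the standard deviation is meaningless when the output can only be 0 or 1: for such values the standard deviation is fully determined by the mean (std = sqrt(mean * (1 - mean))), so all of the information is already encoded in the mean.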



##########
tests/python/frontend/onnx/test_forward.py:
##########
@@ -6663,6 +6663,105 @@ def verify_qlinearsigmoid(a_shape):
     verify_qlinearsigmoid([])
 
 
+@tvm.testing.parametrize_targets("llvm")
+def test_random_bernoulli(target, dev):
+    """test_random_bernoulli"""
+
+    def verify_bernoulli_with_ort(
+        shape,
+        in_dtype="float32",
+        out_dtype="int32",
+        seed=None,
+        out_shape=None,
+        target=target,
+        dev=dev,
+        use_vm=False,
+        opset=None,
+        freeze_params=False,
+        rtol=0.1,
+        atol=0.1,
+        opt_level=1,
+        convert_config=None,
+    ):
+        def get_bernoulli_model(shape, in_dtype="float32", out_dtype="int32", seed=None):
+            onnx_itype = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(in_dtype)]
+            onnx_otype = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(out_dtype)]
+            node = helper.make_node(
+                "Bernoulli",
+                ["input"],
+                ["output"],
+            )
+            dtype_attr = helper.make_attribute("dtype", onnx_otype)
+            node.attribute.append(dtype_attr)
+            if seed is not None:
+                seed_attr = helper.make_attribute("seed", seed)
+                node.attribute.append(seed_attr)
+
+            graph = helper.make_graph(
+                [node],
+                "random_bernoulli_test",
+                inputs=[helper.make_tensor_value_info("input", onnx_itype, list(shape))],
+                outputs=[helper.make_tensor_value_info("output", onnx_otype, list(shape))],
+            )
+            return helper.make_model(graph, producer_name="random_bernoulli_test")
+
+        inputs = np.random.uniform(size=shape).astype(in_dtype)
+        if seed is None:
+            ort_seed = None
+        else:
+            ort_seed = float(seed)
+        model = get_bernoulli_model(shape, in_dtype, out_dtype, ort_seed)
+        if opset is not None:
+            model.opset_import[0].version = opset
+
+        ort_out = get_onnxruntime_output(model, inputs)
+        if use_vm:
+            tvm_out = get_tvm_output_with_vm(
+                model,
+                inputs,
+                target,
+                dev,
+                opset=opset,
+                freeze_params=freeze_params,
+                convert_config=convert_config,
+            )
+        else:
+            tvm_out = get_tvm_output(
+                model,
+                inputs,
+                target,
+                dev,
+                out_shape,
+                opset=opset,
+                opt_level=opt_level,
+                convert_config=convert_config,
+            )
+
+        if not isinstance(tvm_out, list):
+            tvm_out = [tvm_out]
+        if not isinstance(ort_out, list):
+            ort_out = [ort_out]
+        for tvm_val, ort_val in zip(tvm_out, ort_out):
+            tvm.testing.assert_allclose(ort_val.mean(), tvm_val.mean(), rtol=rtol, atol=atol)

Review Comment:
   Hmm I think you need to do two things:
   1. verify all outputs are 0 or 1
   2. verify that the mean of tvm_val ~= the input probability; there is no need to compare against ort_val (or at least we gain no additional confidence in the result from doing so)
   
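   Note that check 2 also has an issue: it is inherently flaky. We could use a Student's t-test or something similar, but, for example, if you set the significance threshold (p-value cutoff) to 0.05, the test will still fail 5% of the time even when the implementation is correct.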
   
   cc @octoJon for consultation



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org