Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/01/30 01:24:18 UTC
[GitHub] [tvm] lazycal opened a new pull request #10106: [onnx] fix onnx where broadcast
lazycal opened a new pull request #10106:
URL: https://github.com/apache/tvm/pull/10106
Thanks for contributing to TVM! Please refer to guideline https://tvm.apache.org/docs/contribute/ for useful information and tips. After the pull request is submitted, please request code reviews from [Reviewers](https://github.com/apache/incubator-tvm/blob/master/CONTRIBUTORS.md#reviewers) by @ them in the pull request thread.
The previous logic for importing `where` only considered the shape of the tensor with the largest rank. This is incorrect: for instance, when cond, x, and y have shapes [3,1], [2], and [2], the result shape after broadcasting should be [3,2], but the original logic gives [3,1]. Below is a code snippet to reproduce the issue.
```python
import numpy as np
from onnx import TensorProto, helper
from tvm import relay


def get_onnx_model(condition, x, y):
    outdata = np.where(condition, x, y)
    dtype = TensorProto.FLOAT
    node = helper.make_node("Where", inputs=["cond", "x", "y"], outputs=["out"])
    graph = helper.make_graph(
        [node],
        "where_test",
        inputs=[
            helper.make_tensor_value_info("cond", TensorProto.BOOL, list(condition.shape)),
            helper.make_tensor_value_info("x", dtype, list(x.shape)),
            helper.make_tensor_value_info("y", dtype, list(y.shape)),
        ],
        outputs=[helper.make_tensor_value_info("out", dtype, list(outdata.shape))],
    )
    model = helper.make_model(graph, producer_name="where_test")
    return model


def main():
    condition = np.random.uniform(size=(3, 1)) < 0.5
    x = np.random.uniform(size=(2,)).astype(np.float32)
    y = np.random.uniform(size=(2,)).astype(np.float32)
    model = get_onnx_model(condition, x, y)
    mod, params = relay.frontend.from_onnx(model, freeze_params=True)
    res = relay.build_module.create_executor("graph", mod).evaluate()(
        **{"cond": condition, "x": x, "y": y}
    )
    # Fails with the old importer logic: result has shape (3, 1) instead of (3, 2).
    assert np.allclose(res.asnumpy(), np.where(condition, x, y), rtol=0, atol=0)


main()
```
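The expected output shape follows standard NumPy/ONNX broadcasting rules across all three inputs. A quick sanity check (assuming NumPy >= 1.20 for `np.broadcast_shapes`):

```python
import numpy as np

# Broadcasting cond [3,1] with x [2] and y [2]: the trailing dims
# broadcast 1 vs 2 -> 2, and the leading dim 3 is kept, giving (3, 2).
print(np.broadcast_shapes((3, 1), (2,), (2,)))  # (3, 2)
```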
This PR simply delegates the broadcast logic to `relay.where` instead of handling it during import.
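To illustrate the difference, here is a sketch contrasting the two behaviors (these helpers are hypothetical, not the actual importer code): the old logic effectively picked the shape of the highest-rank input, while full broadcasting considers all three inputs, as `relay.where` does internally.

```python
import numpy as np

def largest_rank_shape(shapes):
    # Simplified model of the old importer behavior: only the
    # highest-rank input's shape is considered.
    return max(shapes, key=len)

def broadcast_shape(shapes):
    # Full three-way broadcasting across cond, x, and y.
    return np.broadcast_shapes(*shapes)

shapes = [(3, 1), (2,), (2,)]
print(largest_rank_shape(shapes))  # (3, 1) -- wrong
print(broadcast_shape(shapes))     # (3, 2) -- correct
```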
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
[GitHub] [tvm] AndrewZhaoLuo merged pull request #10106: [onnx] fix onnx where broadcast
AndrewZhaoLuo merged pull request #10106:
URL: https://github.com/apache/tvm/pull/10106
[GitHub] [tvm] AndrewZhaoLuo commented on pull request #10106: [onnx] fix onnx where broadcast
AndrewZhaoLuo commented on pull request #10106:
URL: https://github.com/apache/tvm/pull/10106#issuecomment-1026296351
@lazycal yeah looks like a spurious error. You need to jostle ci by pushing an empty commit. E.g. `git commit -m 'jostle ci' --allow-empty` and `git push`
[GitHub] [tvm] lazycal commented on pull request #10106: [onnx] fix onnx where broadcast
lazycal commented on pull request #10106:
URL: https://github.com/apache/tvm/pull/10106#issuecomment-1025994297
Not sure why, but the test succeeds on my local machine, and the failing test is an autotvm test that does not seem related. Maybe it's a stability issue with the CI?
[GitHub] [tvm] lazycal commented on pull request #10106: [onnx] fix onnx where broadcast
lazycal commented on pull request #10106:
URL: https://github.com/apache/tvm/pull/10106#issuecomment-1025045493
@Laurawly Can you take a look?
[GitHub] [tvm] AndrewZhaoLuo commented on pull request #10106: [onnx] fix onnx where broadcast
AndrewZhaoLuo commented on pull request #10106:
URL: https://github.com/apache/tvm/pull/10106#issuecomment-1027119594
Looks like you got unlucky again :(