Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2020/07/16 09:56:21 UTC
[GitHub] [incubator-mxnet] wkcn edited a comment on pull request #18707: [MXNET-1453] Support the intput whose dimension is greater than 6 for Transpose and Rollaxis
URL: https://github.com/apache/incubator-mxnet/pull/18707#issuecomment-659304799
Performance Benchmark:
Transpose operator on CPU, with `axes` generated randomly:
ndim | max use (kB) | avg time (ms)
---|---|---
1|12582.9121|1.0786
2|12582.9121|1.0851
3|12582.9121|0.6763
4|12582.9121|1.2172
5|12582.9121|6.4305
6|12582.9121|11.7841
7|12583.3604|65.7184
8|12583.4238|65.2171
9|12583.4883|82.4930
The memory footprint increases only slightly, but the run time becomes intolerable when `axes.ndim() > 6`: the average time jumps from 11.7841 ms at ndim = 6 to 65.7184 ms at ndim = 7. I will try to optimize it.
If `axes` is monotonically increasing (namely `[0, 1, 2, 3, ..., ndim - 1]`, the identity permutation), the results are:
ndim | max use (kB) | avg time (ms)
---|---|---
1|12582.9121|1.1492
2|12582.9121|1.1732
3|12582.9121|1.3264
4|12582.9121|1.3896
5|12582.9121|0.9107
6|12582.9121|0.8965
7|12583.3604|0.9028
8|12583.4238|0.9105
9|12583.4883|0.8981
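
To compare the two tables directly, the following sketch computes how many times slower random axes are than the identity permutation for each `ndim`. The timings are copied from the tables above; the variable and function names are illustrative only and do not come from the PR.

```python
# Average times (ms) copied from the two tables above, for ndim = 1..9.
random_axes_ms = [1.0786, 1.0851, 0.6763, 1.2172, 6.4305,
                  11.7841, 65.7184, 65.2171, 82.4930]
identity_axes_ms = [1.1492, 1.1732, 1.3264, 1.3896, 0.9107,
                    0.8965, 0.9028, 0.9105, 0.8981]


def slowdown_table(random_ms, identity_ms):
    """Return (ndim, ratio) pairs, where ratio = random time / identity time."""
    pairs = zip(random_ms, identity_ms)
    return [(ndim, r / i) for ndim, (r, i) in enumerate(pairs, start=1)]


ratios = slowdown_table(random_axes_ms, identity_axes_ms)
for ndim, ratio in ratios:
    print(f"ndim={ndim}: random axes take {ratio:.1f}x the identity time")
```

With these numbers the ratio is about 13x at ndim = 6 and roughly 72x to 92x for ndim >= 7.
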
Test Code:
```python
import mxnet as mx
from mxnet import profiler
print(mx)
import numpy as np
from numpy.testing import assert_allclose
import time
import random

# Seed every random source used below so that runs are reproducible.
# random.seed is needed because the axes are shuffled with the stdlib RNG.
seed = 42
random.seed(seed)
np.random.seed(seed)
mx.random.seed(seed)

# Configure the profiler.
profiler.set_config(profile_all=True, aggregate_stats=True, filename='trace_profile.json')


def test_transpose(ndim):
    """Transpose 20 random arrays of rank `ndim` and check them against NumPy."""
    for t in range(20):
        # ndim axes of size 4; the last axis is enlarged so that every
        # shape holds 4 ** 10 elements in total.
        dims = [4 for _ in range(ndim)]
        dims[-1] *= 4 ** (10 - ndim)
        # Random permutation of the axes.
        axes = list(range(ndim))
        random.shuffle(axes)
        axes = tuple(axes)
        x = mx.nd.array(np.random.normal(size=dims))
        y = mx.nd.transpose(x, axes=axes)
        assert_allclose(np.transpose(x.asnumpy(), axes=axes), y.asnumpy())


for ndim in range(1, 10):
    # Start the profiler collecting data.
    profiler.set_state('run')
    tic = time.time()
    test_transpose(ndim)
    print(ndim, "====", time.time() - tic)
    # Stop the profiler.
    profiler.set_state('stop')
    # Dump the profiling data as a string and reset the statistics.
    print(profiler.dumps(reset=True))
print("Over")
```
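
The shape construction in the test code keeps the total number of elements fixed, so every `ndim` transposes the same amount of data. A minimal sketch that checks this, using a hypothetical helper `make_dims` (not part of the PR) that mirrors the two shape lines in `test_transpose`:

```python
import math


def make_dims(ndim):
    """Mirror the shape logic of test_transpose: ndim axes of size 4,
    with the last axis enlarged so the total is always 4 ** 10 elements."""
    dims = [4 for _ in range(ndim)]
    dims[-1] *= 4 ** (10 - ndim)
    return dims


for ndim in range(1, 10):
    dims = make_dims(ndim)
    # math.prod requires Python 3.8 or newer.
    print(ndim, dims, math.prod(dims))
```

Every shape holds 4 ** 10 = 1,048,576 elements, which is consistent with the nearly constant memory figures in the tables.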