You are viewing a plain-text version of this content; the canonical link to the original was not preserved in this rendering.
Posted to commits@mxnet.apache.org by in...@apache.org on 2018/11/13 23:38:25 UTC
[incubator-mxnet] branch master updated: [Example] Fixing Gradcam implementation (#13196)
This is an automated email from the ASF dual-hosted git repository.
indhub pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new e655f62 [Example] Fixing Gradcam implementation (#13196)
e655f62 is described below
commit e655f62bcccdf55fbc62b96cd6b12e7fbe68aaba
Author: Ankit Khedia <36...@users.noreply.github.com>
AuthorDate: Tue Nov 13 15:38:06 2018 -0800
[Example] Fixing Gradcam implementation (#13196)
* fix the Grad-CAM computation
* change the parameter-loading code
* fix a type-conversion issue with older versions of matplotlib
---
docs/tutorials/vision/cnn_visualization.md | 3 ++-
example/cnn_visualization/gradcam.py | 4 ++--
example/cnn_visualization/vgg.py | 16 +++++++++++-----
3 files changed, 15 insertions(+), 8 deletions(-)
diff --git a/docs/tutorials/vision/cnn_visualization.md b/docs/tutorials/vision/cnn_visualization.md
index 940c261..a350fff 100644
--- a/docs/tutorials/vision/cnn_visualization.md
+++ b/docs/tutorials/vision/cnn_visualization.md
@@ -151,7 +151,8 @@ def show_images(pred_str, images):
for i in range(num_images):
fig.add_subplot(rows, cols, i+1)
plt.xlabel(titles[i])
- plt.imshow(images[i], cmap='gray' if i==num_images-1 else None)
+ img = images[i].astype(np.uint8)
+ plt.imshow(img, cmap='gray' if i==num_images-1 else None)
plt.show()
```
diff --git a/example/cnn_visualization/gradcam.py b/example/cnn_visualization/gradcam.py
index a8708f7..54cb65e 100644
--- a/example/cnn_visualization/gradcam.py
+++ b/example/cnn_visualization/gradcam.py
@@ -249,8 +249,8 @@ def visualize(net, preprocessed_img, orig_img, conv_layer_name):
imggrad = get_image_grad(net, preprocessed_img)
conv_out, conv_out_grad = get_conv_out_grad(net, preprocessed_img, conv_layer_name=conv_layer_name)
- cam = get_cam(imggrad, conv_out)
-
+ cam = get_cam(conv_out_grad, conv_out)
+ cam = cv2.resize(cam, (imggrad.shape[1], imggrad.shape[2]))
ggcam = get_guided_grad_cam(cam, imggrad)
img_ggcam = grad_to_image(ggcam)
diff --git a/example/cnn_visualization/vgg.py b/example/cnn_visualization/vgg.py
index b6215a3..a8a0ef6 100644
--- a/example/cnn_visualization/vgg.py
+++ b/example/cnn_visualization/vgg.py
@@ -72,11 +72,17 @@ def get_vgg(num_layers, pretrained=False, ctx=mx.cpu(),
root=os.path.join('~', '.mxnet', 'models'), **kwargs):
layers, filters = vgg_spec[num_layers]
net = VGG(layers, filters, **kwargs)
- if pretrained:
- from mxnet.gluon.model_zoo.model_store import get_model_file
- batch_norm_suffix = '_bn' if kwargs.get('batch_norm') else ''
- net.load_params(get_model_file('vgg%d%s'%(num_layers, batch_norm_suffix),
- root=root), ctx=ctx)
+ net.initialize(ctx=ctx)
+
+ # Get the pretrained model
+ vgg = mx.gluon.model_zoo.vision.get_vgg(num_layers, pretrained=True, ctx=ctx)
+
+ # Set the parameters in the new network
+ params = vgg.collect_params()
+ for key in params:
+ param = params[key]
+ net.collect_params()[net.prefix+key.replace(vgg.prefix, '')].set_data(param.data())
+
return net
def vgg16(**kwargs):