You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink itself was not preserved in this plain-text copy).
Posted to commits@mxnet.apache.org by in...@apache.org on 2018/11/13 23:38:25 UTC

[incubator-mxnet] branch master updated: [Example] Fixing Gradcam implementation (#13196)

This is an automated email from the ASF dual-hosted git repository.

indhub pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new e655f62  [Example] Fixing Gradcam implementation (#13196)
e655f62 is described below

commit e655f62bcccdf55fbc62b96cd6b12e7fbe68aaba
Author: Ankit Khedia <36...@users.noreply.github.com>
AuthorDate: Tue Nov 13 15:38:06 2018 -0800

    [Example] Fixing Gradcam implementation (#13196)
    
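    * fixing Grad-CAM: compute the CAM from the gradients of the conv-layer output (not the image gradients) and resize it to match the image gradient
    
    * changed the code that loads the pretrained VGG parameters
    
    * fixing a type-conversion issue with older versions of matplotlib (images are cast to uint8 before imshow)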
---
 docs/tutorials/vision/cnn_visualization.md |  3 ++-
 example/cnn_visualization/gradcam.py       |  4 ++--
 example/cnn_visualization/vgg.py           | 16 +++++++++++-----
 3 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/docs/tutorials/vision/cnn_visualization.md b/docs/tutorials/vision/cnn_visualization.md
index 940c261..a350fff 100644
--- a/docs/tutorials/vision/cnn_visualization.md
+++ b/docs/tutorials/vision/cnn_visualization.md
@@ -151,7 +151,8 @@ def show_images(pred_str, images):
     for i in range(num_images):
         fig.add_subplot(rows, cols, i+1)
         plt.xlabel(titles[i])
-        plt.imshow(images[i], cmap='gray' if i==num_images-1 else None)
+        img = images[i].astype(np.uint8)
+        plt.imshow(img, cmap='gray' if i==num_images-1 else None)
     plt.show()
 ```
 
diff --git a/example/cnn_visualization/gradcam.py b/example/cnn_visualization/gradcam.py
index a8708f7..54cb65e 100644
--- a/example/cnn_visualization/gradcam.py
+++ b/example/cnn_visualization/gradcam.py
@@ -249,8 +249,8 @@ def visualize(net, preprocessed_img, orig_img, conv_layer_name):
     imggrad = get_image_grad(net, preprocessed_img)
     conv_out, conv_out_grad = get_conv_out_grad(net, preprocessed_img, conv_layer_name=conv_layer_name)
 
-    cam = get_cam(imggrad, conv_out)
-    
+    cam = get_cam(conv_out_grad, conv_out)
+    cam = cv2.resize(cam, (imggrad.shape[1], imggrad.shape[2]))
     ggcam = get_guided_grad_cam(cam, imggrad)
     img_ggcam = grad_to_image(ggcam)
     
diff --git a/example/cnn_visualization/vgg.py b/example/cnn_visualization/vgg.py
index b6215a3..a8a0ef6 100644
--- a/example/cnn_visualization/vgg.py
+++ b/example/cnn_visualization/vgg.py
@@ -72,11 +72,17 @@ def get_vgg(num_layers, pretrained=False, ctx=mx.cpu(),
             root=os.path.join('~', '.mxnet', 'models'), **kwargs):
     layers, filters = vgg_spec[num_layers]
     net = VGG(layers, filters, **kwargs)
-    if pretrained:
-        from mxnet.gluon.model_zoo.model_store import get_model_file
-        batch_norm_suffix = '_bn' if kwargs.get('batch_norm') else ''
-        net.load_params(get_model_file('vgg%d%s'%(num_layers, batch_norm_suffix),
-                                       root=root), ctx=ctx)
+    net.initialize(ctx=ctx)
+    
+    # Get the pretrained model
+    vgg = mx.gluon.model_zoo.vision.get_vgg(num_layers, pretrained=True, ctx=ctx)
+    
+    # Set the parameters in the new network
+    params = vgg.collect_params()
+    for key in params:
+        param = params[key]
+        net.collect_params()[net.prefix+key.replace(vgg.prefix, '')].set_data(param.data())
+   
     return net
 
 def vgg16(**kwargs):