You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by se...@apache.org on 2018/06/30 01:28:01 UTC

[incubator-mxnet] branch master updated: MXNET-336 [Perl] Major Gluon update for Perl API. (#11414)

This is an automated email from the ASF dual-hosted git repository.

sergeykolychev pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new cca0883  MXNET-336 [Perl] Major Gluon update for Perl API. (#11414)
cca0883 is described below

commit cca088366849ec803a9e73501d8f4e706e2552ae
Author: Sergey Kolychev <se...@gmail.com>
AuthorDate: Fri Jun 29 18:27:54 2018 -0700

    MXNET-336 [Perl] Major Gluon update for Perl API. (#11414)
    
    * MXNET-336
    Major Gluon update towards parity with Python's API.
    Miscellaneous bugfixes and improvements.
    New Engine API.
    Module::reshape moved to C++ backend.
    Examples were updated to work on multi-gpu boxes.
    
    * fixing random seed for flaky tests.
    
    * removed redundant row.
    
    * fixed learning rate.
---
 perl-package/AI-MXNet/Changes                      |   7 +
 perl-package/AI-MXNet/MANIFEST                     |   4 +
 perl-package/AI-MXNet/META.json                    |   6 +-
 perl-package/AI-MXNet/META.yml                     |   6 +-
 perl-package/AI-MXNet/README                       |   2 +-
 perl-package/AI-MXNet/examples/char_lstm.pl        |  11 +-
 perl-package/AI-MXNet/examples/gluon/mnist.pl      |   4 +-
 perl-package/AI-MXNet/lib/AI/MXNet.pm              |   8 +-
 perl-package/AI-MXNet/lib/AI/MXNet/AutoLoad.pm     |   2 +-
 perl-package/AI-MXNet/lib/AI/MXNet/Base.pm         |  17 +-
 perl-package/AI-MXNet/lib/AI/MXNet/CachedOp.pm     |  22 +-
 perl-package/AI-MXNet/lib/AI/MXNet/Context.pm      |  44 +
 perl-package/AI-MXNet/lib/AI/MXNet/Engine.pm       |  84 ++
 perl-package/AI-MXNet/lib/AI/MXNet/Executor.pm     | 160 ++--
 .../AI-MXNet/lib/AI/MXNet/Executor/Group.pm        |   2 +-
 perl-package/AI-MXNet/lib/AI/MXNet/Gluon/Block.pm  | 789 ++++++++++++++----
 perl-package/AI-MXNet/lib/AI/MXNet/Gluon/NN.pm     |   1 +
 .../AI-MXNet/lib/AI/MXNet/Gluon/NN/Activation.pm   | 249 ++++++
 .../AI-MXNet/lib/AI/MXNet/Gluon/NN/BasicLayers.pm  | 485 ++++++++---
 .../AI-MXNet/lib/AI/MXNet/Gluon/NN/ConvLayers.pm   |  55 +-
 .../AI-MXNet/lib/AI/MXNet/Gluon/Parameter.pm       | 557 +++++++++++--
 perl-package/AI-MXNet/lib/AI/MXNet/Gluon/RNN.pm    |   2 +-
 .../AI-MXNet/lib/AI/MXNet/Gluon/RNN/Cell.pm        |  19 +-
 .../AI-MXNet/lib/AI/MXNet/Gluon/Trainer.pm         | 385 +++++++--
 perl-package/AI-MXNet/lib/AI/MXNet/Gluon/Utils.pm  |  53 +-
 perl-package/AI-MXNet/lib/AI/MXNet/Module.pm       |  16 +
 perl-package/AI-MXNet/lib/AI/MXNet/NDArray.pm      |  67 +-
 perl-package/AI-MXNet/lib/AI/MXNet/Optimizer.pm    |  12 +-
 perl-package/AI-MXNet/lib/AI/MXNet/Random.pm       |  34 +-
 perl-package/AI-MXNet/lib/AI/MXNet/Symbol.pm       |  12 +
 .../AI-MXNet/lib/AI/MXNet/Visualization.pm         |  12 +
 perl-package/AI-MXNet/t/test_conv.t                |   2 +-
 perl-package/AI-MXNet/t/test_cuda_module.t         |   2 +-
 .../{lib/AI/MXNet/AutoLoad.pm => t/test_engine.t}  |  30 +-
 perl-package/AI-MXNet/t/test_executor.t            |   8 +-
 perl-package/AI-MXNet/t/test_gluon.t               | 909 +++++++++++++++++++--
 perl-package/AI-MXNet/t/test_gluon_trainer.t       | 253 ++++++
 perl-package/AI-MXNet/t/test_loss.t                |  12 +-
 perl-package/AI-MXNet/t/test_ndarray.t             |  81 +-
 perl-package/AI-MXNet/t/test_optimizers.t          |   3 +-
 perl-package/AI-MXNet/t/test_random.t              |  12 +-
 perl-package/AI-MXNet/t/test_symbol.t              |  13 +-
 perl-package/AI-MXNetCAPI/Changes                  |   3 +
 perl-package/AI-MXNetCAPI/META.json                |   2 +-
 perl-package/AI-MXNetCAPI/META.yml                 |   2 +-
 perl-package/AI-MXNetCAPI/README                   |   2 +-
 perl-package/AI-MXNetCAPI/lib/AI/MXNetCAPI.pm      |   2 +-
 perl-package/AI-MXNetCAPI/mxnet.i                  | 106 +++
 perl-package/AI-MXNetCAPI/mxnet_typemaps.i         |  75 +-
 perl-package/AI-NNVMCAPI/Changes                   |   3 +
 perl-package/AI-NNVMCAPI/META.json                 |   2 +-
 perl-package/AI-NNVMCAPI/META.yml                  |   2 +-
 perl-package/AI-NNVMCAPI/README                    |   2 +-
 perl-package/AI-NNVMCAPI/lib/AI/NNVMCAPI.pm        |   2 +-
 perl-package/AI-NNVMCAPI/nnvm_typemaps.i           |   5 +
 55 files changed, 3914 insertions(+), 746 deletions(-)

diff --git a/perl-package/AI-MXNet/Changes b/perl-package/AI-MXNet/Changes
index 6f41751..3939872 100644
--- a/perl-package/AI-MXNet/Changes
+++ b/perl-package/AI-MXNet/Changes
@@ -1,5 +1,12 @@
 Revision history for Perl extension AI::MXNet
 
+1.3     Tue Jun 26 20:57:40 PDT 2018
+        - Major Gluon update towards parity with Python's API.
+        - Miscellaneous bugfixes and improvements.
+        - New Engine API.
+        - Module::reshape moved to C++ backend.
+        - Examples were updated to work on multi-GPU boxes.
+
 1.23    Thu Apr 19 15:38:10 PDT 2018
         - Support for image operations on symbols and ndarrays.
 
diff --git a/perl-package/AI-MXNet/MANIFEST b/perl-package/AI-MXNet/MANIFEST
index 5251495..22a1624 100644
--- a/perl-package/AI-MXNet/MANIFEST
+++ b/perl-package/AI-MXNet/MANIFEST
@@ -25,6 +25,7 @@ lib/AI/MXNet/Contrib.pm
 lib/AI/MXNet/Contrib/NDArray.pm
 lib/AI/MXNet/Contrib/Symbol.pm
 lib/AI/MXNet/CudaModule.pm
+lib/AI/MXNet/Engine.pm
 lib/AI/MXNet/Executor.pm
 lib/AI/MXNet/Executor/Group.pm
 lib/AI/MXNet/Function/Parameters.pm
@@ -38,6 +39,7 @@ lib/AI/MXNet/Gluon/Data/Vision.pm
 lib/AI/MXNet/Gluon/Loss.pm
 lib/AI/MXNet/Gluon/Mouse.pm
 lib/AI/MXNet/Gluon/NN.pm
+lib/AI/MXNet/Gluon/NN/Activation.pm
 lib/AI/MXNet/Gluon/NN/BasicLayers.pm
 lib/AI/MXNet/Gluon/NN/ConvLayers.pm
 lib/AI/MXNet/Gluon/Parameter.pm
@@ -97,10 +99,12 @@ t/test_autograd.t
 t/test_base.t
 t/test_conv.t
 t/test_cuda_module.t
+t/test_engine.t
 t/test_executor.t
 t/test_gluon.t
 t/test_gluon_data.t
 t/test_gluon_rnn.t
+t/test_gluon_trainer.t
 t/test_infer_shape.t
 t/test_init.t
 t/test_io.t
diff --git a/perl-package/AI-MXNet/META.json b/perl-package/AI-MXNet/META.json
index 12b8bb7..fabbd77 100644
--- a/perl-package/AI-MXNet/META.json
+++ b/perl-package/AI-MXNet/META.json
@@ -30,8 +30,8 @@
       },
       "runtime" : {
          "requires" : {
-            "AI::MXNetCAPI" : "1.2",
-            "AI::NNVMCAPI" : "1.2",
+            "AI::MXNetCAPI" : "1.3",
+            "AI::NNVMCAPI" : "1.3",
             "Function::Parameters" : "1.0705",
             "Hash::Ordered" : "0.012",
             "GraphViz" : "2.14",
@@ -45,5 +45,5 @@
       }
    },
    "release_status" : "stable",
-   "version" : "1.23"
+   "version" : "1.3"
 }
diff --git a/perl-package/AI-MXNet/META.yml b/perl-package/AI-MXNet/META.yml
index 51320cf..610704f 100644
--- a/perl-package/AI-MXNet/META.yml
+++ b/perl-package/AI-MXNet/META.yml
@@ -17,12 +17,12 @@ no_index:
     - t
     - inc
 requires:
-  AI::MXNetCAPI: '1.2'
-  AI::NNVMCAPI: '1.2'
+  AI::MXNetCAPI: '1.3'
+  AI::NNVMCAPI: '1.3'
   Function::Parameters: '1.0705'
   Hash::Ordered: '0.012'
   GraphViz: '2.14'
   Mouse: v2.1.0
   PDL: '2.007'
   PDL::CCS: '1.23.4'
-version: '1.23'
+version: '1.3'
diff --git a/perl-package/AI-MXNet/README b/perl-package/AI-MXNet/README
index f0a1adb..cc6372c 100644
--- a/perl-package/AI-MXNet/README
+++ b/perl-package/AI-MXNet/README
@@ -1,5 +1,5 @@
 This archive contains the distribution AI-MXNet,
-version 1.23:
+version 1.3:
 
   Perl interface to MXNet machine learning library
 
diff --git a/perl-package/AI-MXNet/examples/char_lstm.pl b/perl-package/AI-MXNet/examples/char_lstm.pl
index 9a80dda..1e9c385 100755
--- a/perl-package/AI-MXNet/examples/char_lstm.pl
+++ b/perl-package/AI-MXNet/examples/char_lstm.pl
@@ -233,22 +233,23 @@ $model->fit(
     initializer         => mx->init->Xavier(factor_type => "in", magnitude => 2.34),
     num_epoch           => $num_epoch,
     batch_end_callback  => mx->callback->Speedometer($batch_size, $disp_batches),
-    ($chkp_epoch ? (epoch_end_callback  => [mx->rnn->do_rnn_checkpoint($stack, $chkp_prefix, $chkp_epoch), \&sample]) : ())
+    ($chkp_epoch ? (epoch_end_callback  => [mx->callback->module_checkpoint($model, $chkp_prefix, $chkp_epoch), \&sample]) : ())
 );
 
+my $chkp = 1;
 sub sample {
     return if not $sample_size;
-    $model->reshape(data_shapes=>[['data',[1, $seq_size]]], label_shapes=>[['softmax_label',[1, $seq_size]]]);
+    my $inference_model = mx->mod->Module->load($chkp_prefix, $chkp++);
+    $inference_model->bind(data_shapes=>[['data',[1, $seq_size]]], label_shapes=>[['softmax_label',[1, $seq_size]]]);
     my $input = mx->nd->array($fdata->slice([0, $seq_size-1]))->reshape([1, $seq_size]);
     $| = 1;
     for (0..$sample_size-1)
     {
-        $model->forward(mx->io->DataBatch(data=>[$input]), is_train => 0);
-        my $prob = $model->get_outputs(0)->[0][0]->at($seq_size-1)->aspdl;
+        $inference_model->forward(mx->io->DataBatch(data=>[$input]), is_train => 0);
+        my $prob = $inference_model->get_outputs(0)->[0][0]->at($seq_size-1)->aspdl;
         my $next_char = Math::Random::Discrete->new($prob->reshape(-1)->unpdl, [0..scalar(keys %vocabulary)-1])->rand;
         print "$reverse_vocab{$next_char}";
         $input->at(0)->slice([0, $seq_size-2]) .= $input->at(0)->slice([1, $seq_size-1])->copy;
         $input->at(0)->at($seq_size-1) .= $next_char;
     }
-    $model->reshape(data_shapes=>[['data',[$batch_size, $seq_size]]], label_shapes=>[['softmax_label',[$batch_size, $seq_size]]]);
 }
diff --git a/perl-package/AI-MXNet/examples/gluon/mnist.pl b/perl-package/AI-MXNet/examples/gluon/mnist.pl
index 2d4eff0..5492e7e 100755
--- a/perl-package/AI-MXNet/examples/gluon/mnist.pl
+++ b/perl-package/AI-MXNet/examples/gluon/mnist.pl
@@ -48,7 +48,7 @@ $net->name_scope(sub {
     $net->add(nn->Dense(10));
 });
 $net->hybridize() if $hybridize;
-$net->load_params('mnist.params') if $load_params;
+$net->load_parameters('mnist.params') if $load_params;
 # data
 
 sub transformer
@@ -130,7 +130,7 @@ sub train
         my ($val_name, $val_acc) = test($ctx);
         print "[Epoch $epoch] Validation: $val_name=$val_acc\n"
     }
-    $net->save_params('mnist.params');
+    $net->save_parameters('mnist.params');
 }
 
 train($epochs, $cuda ? mx->gpu(0) : mx->cpu);
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet.pm b/perl-package/AI-MXNet/lib/AI/MXNet.pm
index 4de57b5..b9ae39c 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet.pm
@@ -50,7 +50,8 @@ use AI::MXNet::AutoGrad;
 use AI::MXNet::Gluon;
 use AI::MXNet::NDArray::Sparse;
 use AI::MXNet::Symbol::Sparse;
-our $VERSION = '1.22';
+use AI::MXNet::Engine;
+our $VERSION = '1.3';
 
 sub import
 {
@@ -76,6 +77,7 @@ sub import
             sub Context { shift; AI::MXNet::Context->new(\@_) }
             sub context { 'AI::MXNet::Context' }
             sub cpu { AI::MXNet::Context->cpu(\$_[1]//0) }
+            sub cpu_pinned { AI::MXNet::Context->cpu_pinned(\$_[1]//0) }
             sub gpu { AI::MXNet::Context->gpu(\$_[1]//0) }
             sub kv { 'AI::MXNet::KVStore' }
             sub recordio { 'AI::MXNet::RecordIO' }
@@ -92,8 +94,10 @@ sub import
             sub contrib { 'AI::MXNet::Contrib' }
             sub linalg { 'AI::MXNet::LinAlg' }
             sub autograd { 'AI::MXNet::AutoGrad' }
+            sub engine { 'AI::MXNet::Engine' }
             sub name { '$short_name' }
             sub rtc { '$short_name' }
+            sub gluon { 'AI::MXNet::Gluon' }
             sub CudaModule { shift; AI::MXNet::CudaModule->new(\@_) }
             sub AttrScope { shift; AI::MXNet::Symbol::AttrScope->new(\@_) }
             *AI::MXNet::Symbol::AttrScope::current = sub { \$${short_name}::AttrScope; };
@@ -106,6 +110,8 @@ sub import
             *AI::MXNet::Context::current_context = sub { \$${short_name}::Context; };
             *AI::MXNet::Context::set_current = sub { \$${short_name}::Context = \$_[1]; };
             \$${short_name}::Context = AI::MXNet::Context->new(device_type => 'cpu', device_id => 0);
+            package nd;
+            \@nd::ISA = ('AI::MXNet::NDArray');
             1;
 EOP
             eval $short_name_package;
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/AutoLoad.pm b/perl-package/AI-MXNet/lib/AI/MXNet/AutoLoad.pm
index 3dc7a06..927fd53 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/AutoLoad.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/AutoLoad.pm
@@ -27,7 +27,7 @@ sub AUTOLOAD
     my $sub = "_${prefix}_$name";
     {
         no strict 'refs';
-        *{"$class::$name"} = sub { shift; $real_class->$sub(@_); };
+        *{"${class}::$name"} = sub { shift; $real_class->$sub(@_); };
     }
     goto $class->can($name);
 }
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/Base.pm b/perl-package/AI-MXNet/lib/AI/MXNet/Base.pm
index 33b3d4d..8e65468 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/Base.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/Base.pm
@@ -21,8 +21,8 @@ use warnings;
 use PDL;
 use PDL::Types ();
 use PDL::CCS::Nd;
-use AI::MXNetCAPI 1.2;
-use AI::NNVMCAPI 1.2;
+use AI::MXNetCAPI 1.3;
+use AI::NNVMCAPI 1.3;
 use AI::MXNet::Types;
 use Time::HiRes;
 use Scalar::Util qw(blessed);
@@ -169,9 +169,16 @@ sub zip
 
 sub enumerate
 {
-    my ($sub, @arrays) = @_;
-    my $len = @{ $arrays[0] };
-    zip($sub, [0..$len-1], @arrays);
+    if('CODE' eq ref $_[0])
+    {
+        # continue supporting the callback style
+        my $code = shift;
+        my $len = @{ $_[0] };
+        $code->(@$_) for AI::MXNetCAPI::py_zip([0..$len-1], map { \@$_ } @_);
+        return;
+    }
+    my $len = @{ $_[0] };
+    return AI::MXNetCAPI::py_zip([0..$len-1], map { \@$_ } @_);
 }
 
 =head2 product
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/CachedOp.pm b/perl-package/AI-MXNet/lib/AI/MXNet/CachedOp.pm
index 9222afb..27ec6dc 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/CachedOp.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/CachedOp.pm
@@ -32,10 +32,17 @@ has 'handle'   => (is => 'ro', isa => 'CachedOpHandle', required => 1);
 around BUILDARGS => sub {
     my $orig  = shift;
     my $class = shift;
-    my ($sym) = @_;
+    my ($sym, $flags) = @_;
+    for my $key (keys %$flags)
+    {
+        $flags->{ $key } = "(" .join(", ", map { defined($_) ? $_ : 'None' } @{ $flags->{ $key } }) .")"
+                if ref $flags->{ $key } eq 'ARRAY';
+    }
     my $handle = check_call(
-        AI::MXNetCAPI::CreateCachedOp(
-            $sym->handle
+        AI::MXNetCAPI::CreateCachedOpEx(
+            $sym->handle,
+            scalar(keys %{ $flags//{} }),
+            $flags//{},
         )
     );
     return $class->$orig(handle => $handle);
@@ -84,8 +91,8 @@ sub call
     {
         $out = [];
     }
-    my $output = check_call(
-        AI::MXNetCAPI::InvokeCachedOp(
+    my ($output, $stypes) = check_call(
+        AI::MXNetCAPI::InvokeCachedOpEx(
             $self->handle,
             scalar(@args),
             [map { $_->handle } @args],
@@ -95,11 +102,12 @@ sub call
     return $original_output if defined $original_output;
     if(@$output == 1)
     {
-        return AI::MXNet::NDArray->_ndarray_cls($output->[0]);
+        return AI::MXNet::NDArray->_ndarray_cls($output->[0], 1, $stypes->[0]);
     }
     else
     {
-        return [map { AI::MXNet::NDArray->_ndarray_cls($_) } @$output];
+        my $i = 0;
+        return [map { AI::MXNet::NDArray->_ndarray_cls($_, 1, $stypes->[$i++]) } @$output];
     }
 }
 
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/Context.pm b/perl-package/AI-MXNet/lib/AI/MXNet/Context.pm
index d21b690..e116e6e7 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/Context.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/Context.pm
@@ -19,6 +19,7 @@ package AI::MXNet::Context;
 use strict;
 use warnings;
 use Mouse;
+use AI::MXNet::Base;
 use AI::MXNet::Types;
 use AI::MXNet::Function::Parameters;
 use constant devtype2str => { 1 => 'cpu', 2 => 'gpu', 3 => 'cpu_pinned' };
@@ -111,6 +112,28 @@ method cpu(Int $device_id=0)
     return $self->new(device_type => 'cpu', device_id => $device_id);
 }
 
+=head2 cpu_pinned
+
+    Returns a CPU pinned memory context. Copying from CPU pinned memory to GPU
+    is faster than from normal CPU memory.
+
+    Parameters
+    ----------
+    device_id : int, optional
+        The device id of the device. `device_id` is not needed for CPU.
+        This is included to make the interface compatible with GPU.
+
+    Returns
+    -------
+    context : Context
+        The corresponding CPU pinned memory context.
+=cut
+
+method cpu_pinned(Int $device_id=0)
+{
+    return $self->new(device_type => 'cpu_pinned', device_id => $device_id);
+}
+
 =head2 gpu
 
     Returns a GPU context.
@@ -139,6 +162,27 @@ method gpu(Int $device_id=0)
     $default_ctx : AI::MXNet::Context
 =cut
 
+
+=head2 num_gpus
+
+    Query CUDA for the number of GPUs present.
+
+    Raises
+    ------
+    Will raise an exception on any CUDA error.
+
+    Returns
+    -------
+    count : int
+        The number of GPUs.
+
+=cut
+
+method num_gpus()
+{
+    return scalar(check_call(AI::MXNetCAPI::GetGPUCount()));
+}
+
 method current_ctx()
 {
     return $AI::MXNet::current_ctx;
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/Engine.pm b/perl-package/AI-MXNet/lib/AI/MXNet/Engine.pm
new file mode 100644
index 0000000..c4ee262
--- /dev/null
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/Engine.pm
@@ -0,0 +1,84 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+package AI::MXNet::Engine;
+use strict;
+use warnings;
+use AI::MXNet::Function::Parameters;
+use AI::MXNet::Base;
+=head1 NAME
+
+    AI::MXNet::Engine - Engine properties management.
+=cut
+
+=head2 set_bulk_size
+
+    Set size limit on bulk execution.
+
+    Bulk execution bundles many operators to run together.
+    This can improve performance when running a lot of small
+    operators sequentially.
+
+    Parameters
+    ----------
+    $size : int
+        Maximum number of operators that can be bundled in a bulk.
+
+    Returns
+    -------
+    int
+        Previous bulk size.
+=cut
+
+method set_bulk_size(Int $size)
+{
+    return scalar(check_call(AI::MXNetCAPI::EngineSetBulkSize($size)));
+}
+
+
+=head2 bulk
+
+    Bulk execution bundles many operators to run together.
+    This can improve performance when running a lot of small
+    operators sequentially.
+
+    Parameters
+    ----------
+    $size : int
+        Maximum number of operators that can be bundled in a bulk.
+    $sub : CodeRef, the code to run while the bulk size is in effect.
+
+    my $x;
+    mx->engine->bulk(10, sub {
+        $x = mx->nd->zeros([1]);
+        for my $i (1..100)
+        {
+            $x += 1;
+        }
+    });
+=cut
+
+method bulk(Int $size, CodeRef $sub)
+{
+    my $prev = __PACKAGE__->set_bulk_size($size);
+    eval { $sub->() };
+    my $err = $@;
+    __PACKAGE__->set_bulk_size($prev) unless $prev == $size;
+    Carp::confess($err) if $err;
+}
+
+1;
\ No newline at end of file
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/Executor.pm b/perl-package/AI-MXNet/lib/AI/MXNet/Executor.pm
index ebda90f..edcaabe 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/Executor.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/Executor.pm
@@ -429,120 +429,70 @@ method copy_params_from(
 
 method reshape(HashRef[Shape] $kwargs, Int :$partial_shaping=0, Int :$allow_up_sizing=0)
 {
-    my ($arg_shapes, undef, $aux_shapes) = $self->_symbol->infer_shape(%{ $kwargs });
-    confess("Insufficient argument shapes provided.")
-        unless defined $arg_shapes;
-    my %new_arg_dict;
-    my %new_grad_dict;
-    my $i = 0;
-    for my $name (@{ $self->_symbol->list_arguments() })
+    my @provided_arg_shape_data;
+    # offsets into @provided_arg_shape_data: the shape of the i-th provided arg
+    # occupies positions [$provided_arg_shape_idx[i], $provided_arg_shape_idx[i+1])
+    my @provided_arg_shape_idx = (0);
+    my @provided_arg_shape_names = ();  # provided argument names
+    while(my ($k, $v) = each %{ $kwargs })
     {
-        my $new_shape = $arg_shapes->[$i];
-        my $arr       = $self->arg_arrays->[$i];
-        my $darr;
-        if(@{ $self->grad_arrays })
+        if(ref $v eq 'ARRAY')
         {
-            $darr = $self->grad_arrays->[$i];
+            push @provided_arg_shape_names, $k;
+            push @provided_arg_shape_data, @{ $v };
+            push @provided_arg_shape_idx, scalar(@provided_arg_shape_data);
         }
-        if(
-            $partial_shaping
-                or
-            exists $kwargs->{ $name }
-                or
-            join(',', @{ $new_shape }) eq join(',', @{ $arr->shape })
-        )
-        {
-            if(AI::MXNet::NDArray->size($new_shape) > $arr->size)
-            {
-                confess(
-                    "New shape of arg:$name larger than original. "
-                    ."First making a big executor and then down sizing it "
-                    ."is more efficient than the reverse."
-                    ."If you really want to up size, set \$allow_up_sizing=1 "
-                    ."to enable allocation of new arrays."
-                ) unless $allow_up_sizing;
-                $new_arg_dict{ $name }  = AI::MXNet::NDArray->empty(
-                    $new_shape,
-                    ctx => $arr->context,
-                    dtype => $arr->dtype
-                );
-                if(defined $darr)
-                {
-                    $new_grad_dict{ $name } = AI::MXNet::NDArray->empty(
-                        $new_shape,
-                        ctx => $darr->context,
-                        dtype => $arr->dtype
-                    );
-                }
-            }
-            else
-            {
-                $new_arg_dict{ $name } = $arr->reshape($new_shape);
-                if(defined $darr)
-                {
-                    $new_grad_dict{ $name } = $darr->reshape($new_shape);
-                }
-            }
-        }
-        else
-        {
-            confess(
-                    "Shape of unspecified array arg:$name changed. "
-                    ."This can cause the new executor to not share parameters "
-                    ."with the old one. Please check for error in network."
-                    ."If this is intended, set partial_shaping=True to suppress this warning."
-            );
-        }
-        $i++;
     }
-    my %new_aux_dict;
-    $i = 0;
-    for my $name (@{ $self->_symbol->list_auxiliary_states() })
+
+    my @ctx_map_keys;
+    my @ctx_map_dev_types;
+    my @ctx_map_dev_ids;
+
+    if(ref $self->_group2ctx eq 'HASH')
     {
-        my $new_shape = $aux_shapes->[$i];
-        my $arr = $self->aux_arrays->[$i];
-        if($partial_shaping or join(',', @{ $new_shape }) eq join (',', @{ $arr->shape }))
-        {
-            if(AI::MXNet::NDArray->size($new_shape) > $arr->size)
-            {
-                confess(
-                    "New shape of arg:$name larger than original. "
-                    ."First making a big executor and then down sizing it "
-                    ."is more efficient than the reverse."
-                    ."If you really want to up size, set \$allow_up_sizing=1 "
-                    ."to enable allocation of new arrays."
-                ) unless $allow_up_sizing;
-                $new_aux_dict{ $name }  = AI::MXNet::NDArray->empty(
-                    $new_shape,
-                    ctx => $arr->context,
-                    dtype => $arr->dtype
-                );
-            }
-            else
-            {
-                $new_aux_dict{ $name } = $arr->reshape($new_shape);
-            }
-        }
-        else
+        while(my ($k, $v) = each %{ $self->_group2ctx })
         {
-            confess(
-                "Shape of unspecified array aux:$name changed. "
-                ."This can cause the new executor to not share parameters "
-                ."with the old one. Please check for error in network."
-                ."If this is intended, set partial_shaping=True to suppress this warning."
-            );
+            push @ctx_map_keys, $k;
+            push @ctx_map_dev_types, $v->device_type_id;
+            push @ctx_map_dev_ids, $v->device_id;
         }
-        $i++;
     }
-    return $self->_symbol->bind(
-                ctx         => $self->_ctx,
-                args        => \%new_arg_dict,
-                args_grad   => \%new_grad_dict,
-                grad_req    => $self->_grad_req,
-                aux_states  => \%new_aux_dict,
-                group2ctx   => $self->_group2ctx,
-                shared_exec => $self
+
+    my $shared_handle = $self->handle;
+
+    my ($in_args_and_grad_handles, $aux_state_handles, $handle) = check_call(
+        AI::MXNetCAPI::ExecutorReshape(
+            $partial_shaping,
+            $allow_up_sizing,
+            $self->_ctx->device_type_id,
+            $self->_ctx->device_id,
+            scalar(@ctx_map_keys),
+            \@ctx_map_keys,
+            \@ctx_map_dev_types,
+            \@ctx_map_dev_ids,
+            scalar(@provided_arg_shape_names),
+            \@provided_arg_shape_names,
+            \@provided_arg_shape_data,
+            \@provided_arg_shape_idx,
+            $shared_handle
+        )
+    );
+    my ($in_args_handles, $arg_grad_handles) = @{ $in_args_and_grad_handles };
+    my @arg_arrays  = map { AI::MXNet::NDArray->_ndarray_cls($_) } @{ $in_args_handles };
+    my @grad_arrays = map { defined($_) ? AI::MXNet::NDArray->_ndarray_cls($_) : undef } @{ $arg_grad_handles };
+    my @aux_arrays  = map { AI::MXNet::NDArray->_ndarray_cls($_) } @{ $aux_state_handles };
+
+    my $executor = __PACKAGE__->new(
+        handle     => $handle,
+        symbol    => $self->_symbol,
+        ctx       => $self->_ctx,
+        grad_req  => $self->_grad_req,
+        group2ctx => $self->_group2ctx
     );
+    $executor->arg_arrays(\@arg_arrays);
+    $executor->grad_arrays(\@grad_arrays);
+    $executor->aux_arrays(\@aux_arrays);
+    return $executor;
 }
 
 =head2 debug_str
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/Executor/Group.pm b/perl-package/AI-MXNet/lib/AI/MXNet/Executor/Group.pm
index acacffd..161a2d7 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/Executor/Group.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/Executor/Group.pm
@@ -47,7 +47,7 @@ func _split_input_slice($batch_size, $work_load_list)
         $end = int(min($begin + $batch_num, $batch_size));
         if($begin >= $end)
         {
-            confess('Too many slices such that some splits are empty');
+            Carp::confess('Too many slices such that some splits are empty');
         }
         push @slices, [$begin, $end];
     }
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/Block.pm b/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/Block.pm
index 148df04..60c62cf 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/Block.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/Block.pm
@@ -75,6 +75,7 @@ method create($prefix, $params, $hint)
 
 method __enter__()
 {
+    return $self if $self->_block->_empty_prefix;
     $self->_old_scope($_current);
     $_current = $self;
     $self->_name_scope(AI::MXNet::Symbol::NameManager->current);
@@ -84,6 +85,7 @@ method __enter__()
 
 method __exit__()
 {
+    return if $self->_block->_empty_prefix;
     AI::MXNet::Symbol::NameManager->set_current($self->_name_scope);
     $self->_name_scope(undef);
     $_current = $self->_old_scope;
@@ -101,44 +103,53 @@ use AI::MXNet::Gluon::Mouse;
     Base class for all neural network layers and models. Your models should
     subclass this class.
 
-    `Block` can be nested recursively in a tree structure. You can create and
-    assign child `Block` as regular attributes::
+    AI::MXNet::Gluon::Block can be nested recursively in a tree structure. You can create and
+    assign child AI::MXNet::Gluon::Block objects as regular attributes:
 
-        from mxnet.gluon import Block, nn
-        from mxnet import ndarray as F
+    use AI::MXNet::Gluon::NN qw(nn);
+    use AI::MXNet qw(mx);
 
-        class Model(Block):
-            def __init__(self, **kwargs):
-                super(Model, self).__init__(**kwargs)
-                # use name_scope to give child Blocks appropriate names.
-                # It also allows sharing Parameters between Blocks recursively.
-                with self.name_scope():
-                    self.dense0 = nn.Dense(20)
-                    self.dense1 = nn.Dense(20)
+    package Model;
+    use AI::MXNet::Gluon::Mouse;
+    use AI::MXNet::Function::Parameters;
+    extends 'AI::MXNet::Gluon::Block';
 
-                x = F.relu(self.dense0(x))
-                return F.relu(self.dense1(x))
+    sub BUILD
+    {
+        my $self = shift;
+        $self->name_scope(sub {
+            $self->dense0(nn->Dense(5, in_units=>10));
+            $self->dense1(nn->Dense(5, in_units=>5));
+        });
+    }
 
-        model = Model()
-        model.initialize(ctx=mx.cpu(0))
-        model(F.zeros((10, 10), ctx=mx.cpu(0)))
+    method forward($x)
+    {
+        return $self->dense1->($self->dense0->($x));
+    }
 
+    my $model = Model->new();
+    $model->initialize(ctx=>mx->cpu(0));
+    $model->(nd->zeros([10, 10], ctx=>mx->cpu(0)));
 
-    Child `Block` assigned this way will be registered and `collect_params`
+
+    Child AI::MXNet::Gluon::Block assigned this way will be registered and ->collect_params
     will collect their Parameters recursively.
 
     Parameters
     ----------
-    prefix : str
-        Prefix acts like a name space. It will be prepended to the names of all
-        Parameters and child `Block`s in this `Block`'s `name_scope`. Prefix
-        should be unique within one model to prevent name collisions.
-    params : ParameterDict or None
-        `ParameterDict` for sharing weights with the new `Block`. For example,
-        if you want `dense1` to share `dense0`'s weights, you can do::
-
-            dense0 = nn.Dense(20)
-            dense1 = nn.Dense(20, params=dense0.collect_params())
+    prefix : Str
+        Prefix acts like a name space. All child blocks created in the parent block's
+        name_scope will have the parent block's prefix in their name.
+        Please refer to the naming tutorial at http://mxnet.incubator.apache.org/tutorials/gluon/naming.html
+        for more info on prefix and naming.
+
+    params : AI::MXNet::Gluon::ParameterDict or undef
+        AI::MXNet::Gluon::ParameterDict for sharing weights with the new AI::MXNet::Gluon::Block. For example,
+        if you want `dense1` to share `dense0`'s weights, you can do:
+
+        my $dense0 = nn->Dense(20);
+        my $dense1 = nn->Dense(20, params=>$dense0->collect_params());
 =cut
 
 method _flatten(
@@ -202,8 +213,9 @@ method _regroup(
 
 has _prefix => (is => 'rw', init_arg => 'prefix', isa => 'Str');
 has _params => (is => 'rw', init_arg => 'params', isa => 'Maybe[AI::MXNet::Gluon::ParameterDict]');
-has [qw/_name _scope/] => (is => 'rw', init_arg => undef);
-has [qw/_children/]    => (is => 'rw', init_arg => undef, default => sub { [] });
+has [qw/_name _scope _empty_prefix/] => (is => 'rw', init_arg => undef);
+has [qw/_children _forward_hooks _forward_pre_hooks/]  => (is => 'rw', init_arg => undef, default => sub { Hash::Ordered->new });
+has '_reg_params' => (is => 'rw', init_arg => undef, default => sub { +{} });
 around BUILDARGS => \&AI::MXNet::Base::process_arguments;
 
 sub AUTOLOAD {
@@ -217,6 +229,7 @@ sub AUTOLOAD {
 sub BUILD
 {
     my $self = shift;
+    $self->_empty_prefix(defined $self->_prefix and $self->_prefix eq '');
     my ($prefix, $params) = AI::MXNet::Gluon::BlockScope->create($self->_prefix, $self->_params, $self->_alias);
     $self->_prefix($prefix);
     $self->_params($params);
@@ -255,20 +268,77 @@ method __setattr__($name, $current, $prev=)
                 )
             );
         }
-        if(blessed $current and $current->isa('AI::MXNet::Gluon::Block'))
+    }
+    if(blessed $current and $current->isa('AI::MXNet::Gluon::Block'))
+    {
+        $self->register_child($current, $name);
+    }
+    elsif(blessed $current and $current->isa('AI::MXNet::Gluon::Parameter'))
+    {
+        if(exists $self->_reg_params->{ $name })
         {
-            for(my $i = 0; $i < @{ $self->_children }; $i++)
+            confess("Overriding Parameter attribute $name is not allowed. ".
+                "If you want to share parameters between blocks, please set ".
+                "'params' at Block construction instead."
+            );
+        }
+        $self->_reg_params->{ $name } = $current;
+    }
+}
+
+method _check_container_with_block()
+{
+    my $_find_block_in_container;  # recursive helper; weakened below so the closure does not leak
+    my $finder = $_find_block_in_container = sub { my ($data) = @_;
+    # Find whether a nested container structure contains Blocks
+        if(ref $data eq 'ARRAY')
+        {
+            for my $ele (@{ $data })
             {
-                if(Scalar::Util::refaddr($self->_children->[$i]) eq Scalar::Util::refaddr($prev))
+                if($_find_block_in_container->($ele))
                 {
-                    $self->_children->[$i] = $current;
+                    return 1;
                 }
             }
+            return 0;
         }
-    }
-    if(blessed $current and $current->isa('AI::MXNet::Gluon::Block'))
+        elsif(ref $data eq 'HASH')
+        {
+            for my $v (values %$data)
+            {
+                if($_find_block_in_container->($v))
+                {
+                    return 1;
+                }
+            }
+            return 0;
+        }
+        elsif(blessed $data and $data->isa('AI::MXNet::Gluon::Block'))
+        {
+            return 1;
+        }
+        else
+        {
+            return 0;
+        }
+    }; Scalar::Util::weaken($_find_block_in_container);  # $finder keeps the sub alive; no self-reference cycle
+    my $attributes_hash = $self->attributes_hash();
+    while(my ($k, $v) = each %{ $attributes_hash })
     {
-        $self->register_child($current);
+        if((ref $v eq 'HASH' or ref $v eq 'ARRAY') and not $k =~ /^__/)
+        {
+            if($_find_block_in_container->($v))
+            {
+                AI::MXNet::Logging->warning(
+                    '"%s" is a container with Blocks. '.
+                    'Note that Blocks inside the list, tuple or dict will not be '.
+                    'registered automatically. Make sure to register them using '.
+                    'register_child() or switching to '.
+                    'nn->Sequential/nn->HybridSequential instead. ',
+                    $self->_class_name.'.'.$k
+                );
+            }
+        }
     }
 }
 
@@ -286,7 +356,7 @@ use overload
     '""' => sub
     {
         my $self = shift;
-        my $s = "%s(\n{%s}\n)";
+        my $s = "%s(\n%s\n)";
         my @blocks;
         my %attributes_hash = %{ $self->attributes_hash };
         while(my ($k, $v) = each %attributes_hash)
@@ -296,7 +366,7 @@ use overload
                 push @blocks, "  ($k): ".AI::MXNet::Base::_indent("$v", 2);
             }
         }
-        sprintf("%s(\n{%s}\n)", $self->_class_name, join("\n", @blocks));
+        sprintf("%s(\n%s\n)", $self->_class_name, join("\n", @blocks));
     },
     '&{}' => sub { my $self = shift; sub { $self->call(@_) } };
 
@@ -318,8 +388,10 @@ method class()
 method name_scope(CodeRef $sub)
 {
     $self->_scope->__enter__;
-    $sub->();
+    eval { $sub->(); };
+    my $err = $@;
     $self->_scope->__exit__;
+    confess($err) if $err;
 }
 
 =head2 params
@@ -335,22 +407,51 @@ method params()
 
 =head2 collect_params
 
-        Returns a `ParameterDict` containing this `Block` and all of its
-        children's Parameters.
+        Returns an AI::MXNet::Gluon::ParameterDict containing this AI::MXNet::Gluon::Block and all of its
+        children's Parameters (default); it can also return a ParameterDict
+        containing only the parameters whose names match a regular expression.
+
+        For example, to collect the parameters ['conv1_weight', 'conv1_bias', 'fc_weight',
+        'fc_bias']:
+
+            $model->collect_params('conv1_weight|conv1_bias|fc_weight|fc_bias')
+
+        or collects all parameters that have the name end with 'weight' or 'bias', this can be done
+        using regular expressions.
+
+            $model->collect_params('.*weight|.*bias')
+
 =cut
 
-method collect_params()
+method collect_params(Maybe[Str] $select=)
 {
+    $self->_check_container_with_block();
     my $ret = AI::MXNet::Gluon::ParameterDict->new(prefix => $self->_params->prefix);
-    $ret->update($self->params);
-    for my $cld (@{ $self->_children })
+    $ret->update($self->params, $select);
+    for my $cld ($self->_children->values)
     {
-        $ret->update($cld->collect_params());
+        $ret->update($cld->collect_params($select));
     }
     return $ret;
 }
 
-=head2 save
+
+method _collect_params_with_prefix(Str $prefix='')
+{
+    # Returns { "<child path>.<attribute name>" => Parameter } for this block and all children.
+    # Test length, not truth: child names may be "0" (see register_child), which is false in Perl.
+    $prefix .= '.' if length $prefix;
+    my %ret = map { $prefix.$_ => $self->_reg_params->{ $_ } } keys %{ $self->_reg_params };
+    my $iter = $self->_children->iterator;
+    while(my ($name, $child) = $iter->())
+    {
+        my $child_params = $child->_collect_params_with_prefix("$prefix$name");
+        @ret{ keys %$child_params } = values %$child_params;
+    }
+    return \%ret;
+}
+
+=head2 save_parameters
 
         Save parameters to file.
 
@@ -358,12 +459,14 @@ method collect_params()
             Path to file.
 =cut
 
-method save_params($filename)
+method save_parameters(Str $filename)
 {
-    $self->collect_params->save($filename, $self->prefix);
+    my $params = $self->_collect_params_with_prefix();
+    my %arg_dict = map { $_ => $params->{$_}->_reduce } keys %{ $params };
+    AI::MXNet::NDArray->save($filename, \%arg_dict);
 }
 
-=head2 load
+=head2 load_parameters
 
         Load parameters from file.
 
@@ -378,20 +481,58 @@ method save_params($filename)
             present in this Block.
 =cut
 
-method load_params(
+method load_parameters(
     Str   $filename,
-    Maybe [AI::MXNet::Context|ArrayRef[AI::MXNet::Context]] :$ctx=,
+    AI::MXNet::Context|ArrayRef[AI::MXNet::Context] :$ctx=AI::MXNet::Context->current_ctx,
     Bool  :$allow_missing=0,
     Bool  :$ignore_extra=0
 )
 {
-    $self->collect_params->load(
-        $filename,
-        ($ctx ? (ctx   => $ctx) : ()),
-        allow_missing  => $allow_missing,
-        ignore_extra   => $ignore_extra,
-        restore_prefix => $self->prefix
-    );
+    my $loaded = AI::MXNet::NDArray->load($filename);
+    my $params = $self->_collect_params_with_prefix;
+    return if not keys %$loaded and not keys %$params;
+
+    if(not grep { /\./ } keys %$loaded)
+    {
+        # legacy loading
+        %$loaded = ();
+        $self->collect_params->load(
+            $filename,
+            ($ctx ? (ctx   => $ctx) : ()),
+            allow_missing  => $allow_missing,
+            ignore_extra   => $ignore_extra,
+            restore_prefix => $self->prefix
+        );
+        return;
+    }
+
+    if(not $allow_missing)
+    {
+        for my $name (keys %$params)
+        {
+            if(not exists $loaded->{$name})
+            {
+                confess(
+                    "Parameter $name is missing in file $filename, which contains parameters:".
+                    join(',', keys %$loaded)."\n".
+                    "Set allow_missing=>1 to ignore missing parameters."
+                );
+            }
+        }
+    }
+    for my $name (keys %$loaded)
+    {
+        if(not $ignore_extra and not exists $params->{ $name })
+        {
+            confess(
+                "Parameter $name loaded from file $filename is not present in ParameterDict, ".
+                "which contains parameters ".
+                join(',', keys %$params)."\n".
+                "Set ignore_extra=>1 to ignore."
+            );
+        }
+        $params->{$name}->_load_init($loaded->{$name}, $ctx) if exists $params->{$name};
+    }
 }
 
 =head2 register_child
@@ -400,25 +541,111 @@ method load_params(
         attributes will be registered automatically.
 =cut
 
-method register_child(AI::MXNet::Gluon::Block $block)
+method register_child(AI::MXNet::Gluon::Block $block, Maybe[Str] $name=)
+{
+    $name //= $self->_children->keys;
+    $self->_children->set($name, $block);
+}
+
+=head2 register_forward_pre_hook
+
+        Registers a forward pre-hook on the block.
+
+        The hook function is called immediately before 'forward'.
+        It should not modify the input or output.
+
+        Parameters
+        ----------
+        $hook : CodeRef or callable object
+            The forward hook function of form $hook->($block, $input).
+
+        Returns
+        -------
+        AI::MXNet::Gluon::Utils::HookHandle
+=cut
+
+method register_forward_pre_hook($hook)
+{
+    my $handle = AI::MXNet::Gluon::Utils::HookHandle->new;
+    $handle->attach($self->_forward_pre_hooks, $hook);
+    return $handle;
+}
+
+=head2 register_forward_hook
+
+        Registers a forward hook on the block.
+
+        The hook function is called immediately after 'forward'.
+        It should not modify the input or output.
+
+        Parameters
+        ----------
+        $hook : CodeRef or callable object
+            The forward hook function of form $hook->($block, $input).
+
+        Returns
+        -------
+        AI::MXNet::Gluon::Utils::HookHandle
+=cut
+
+method register_forward_hook($hook)
+{
+    my $handle = AI::MXNet::Gluon::Utils::HookHandle->new;
+    $handle->attach($self->_forward_hooks, $hook);
+    return $handle;
+}
+
+=head2 apply
+
+        Applies $fn recursively to every child block as well as self.
+
+        Parameters
+        ----------
+        $fn : callable
+            Function to be applied to each submodule, of form `$fn->($block)`.
+
+        Returns
+        -------
+        this block
+=cut
+
+method apply($fn)
 {
-    push @{ $self->_children }, $block;
+    for my $cld ($self->_children->values)
+    {
+        $cld->apply($fn);
+    }
+    $fn->($self);
+    return $self;
 }
 
 =head2 initialize
 
-        Initializes `Parameter`s of this `Block` and its children.
 
-        Equivalent to `block.collect_params().initialize(...)`
+        Initializes AI::MXNet::Gluon::Parameters of this AI::MXNet::Gluon::Block and its children.
+        Equivalent to $block->collect_params()->initialize(...)
+
+        Parameters
+        ----------
+        $init : Initializer
+            Global default Initializer to be used when Parameter->init is undefined.
+            Otherwise, Parameter->init takes precedence.
+        ctx : Context or array ref of Context
+            Keeps a copy of Parameters on one or many context(s).
+        verbose : bool, default False
+            Whether to verbosely print out details on initialization.
+        force_reinit : bool, default False
+            Whether to force re-initialization if parameter is already initialized.
 =cut
 
 method initialize(
     Initializer $init=AI::MXNet::Initializer->Uniform(),
     AI::MXNet::Context|ArrayRef[AI::MXNet::Context] :$ctx=AI::MXNet::Context->current_ctx,
-    Bool :$verbose=0
+    Bool :$verbose=0,
+    Bool :$force_reinit=0
 )
 {
-    $self->collect_params->initialize(init => $init, ctx => $ctx, verbose => $verbose);
+    $self->collect_params->initialize(init => $init, ctx => $ctx, verbose => $verbose, force_reinit => $force_reinit);
 }
 
 
@@ -429,18 +656,58 @@ method initialize(
 
         Parameters
         ----------
-        active : bool, default True
+        $active : bool, default True
             Whether to turn hybrid on or off.
+        :$static_alloc : bool, default False
+            Statically allocate memory to improve speed. Memory usage may increase.
+        :$static_shape : bool, default False
+            Optimize for invariant input shapes between iterations. Must also
+            set static_alloc to True. Change of input shapes is still allowed
+            but slower.
 =cut
 
-method hybridize(Bool $active=1)
+method hybridize(
+    Bool $active=1,
+    %args
+)
 {
-    $_->hybridize($active) for @{ $self->_children };
+    $_->hybridize(
+        $active,
+        %args
+    ) for $self->_children->values;
+}
+
+=head2 cast
+
+        Cast this Block to use another data type.
+
+        Parameters
+        ----------
+        dtype : Dtype
+            The new data type.
+=cut
+
+method cast(Dtype $dtype)
+{
+    for my $child ($self->_children->values)
+    {
+        $child->cast($dtype);
+    }
+    $_->cast($dtype) for $self->params->values;
 }
 
 method call(@args)
 {
-    return $self->forward(@args);
+    for my $hook ($self->_forward_pre_hooks->values)
+    {
+        $hook->($self, @args);
+    }
+    my @out = $self->forward(@args);
+    for my $hook ($self->_forward_hooks->values)
+    {
+        $hook->($self, @args);
+    }
+    return wantarray ? @out : $out[0];
 }
 
 =head2 forward
@@ -469,7 +736,6 @@ method register(Str $container)
 __PACKAGE__->register('AI::MXNet::Gluon');
 
 package AI::MXNet::Gluon::HybridBlock;
-
 =head2 NAME
 
     AI::MXNet::Gluon::HybridBlock
@@ -497,38 +763,105 @@ use AI::MXNet::Gluon::Mouse;
 use AI::MXNet::Base;
 extends 'AI::MXNet::Gluon::Block';
 has [qw/
-        _reg_params _cached_graph
-        _cached_op _cached_params
+        _cached_graph
+        _cached_op
         _out_format _in_format
-        _active _in_idx
+        _active _flags _cached_op_args
 /] => (is => 'rw', init_arg => undef);
 
 sub BUILD
 {
     my $self = shift;
-    $self->_reg_params({});
-    $self->_cached_graph([]);
     $self->_active(0);
+    $self->_flags([]);
+    $self->_cached_graph([]);
+    $self->_cached_op_args([]);
 }
 
 method __setattr__($name, $current, $prev=)
 {
     $self->SUPER::__setattr__($name, $current, $prev);
-    if(blessed $current and $current->isa('AI::MXNet::Gluon::Parameter'))
+    if(blessed $current and $current->isa('AI::MXNet::Gluon::HybridBlock'))
     {
-        $self->_reg_params->{ $name } = $current;
+        $self->_clear_cached_op();
     }
 }
 
-method register_child(AI::MXNet::Gluon::HybridBlock $block)
+method register_child(AI::MXNet::Gluon::HybridBlock $block, Maybe[Str] $name=)
 {
-    push @{ $self->_children }, $block;
+    $self->SUPER::register_child($block, $name);
+    $self->_clear_cached_op();
 }
 
-method hybridize(Bool $active=1)
+method hybridize(@args)
 {
+    my $active;
+    if(@args%2)
+    {
+        $active = shift(@args);
+    }
+    else
+    {
+        $active = 1;
+    }
     $self->_active($active);
-    $self->SUPER::hybridize($active);
+    @{ $self->_flags } = @args;
+    $self->_clear_cached_op();
+    if($self->_active and ($self->_forward_hooks or $self->_forward_pre_hooks))
+    {
+        AI::MXNet::Logging->warning(
+            "$self is being hybridized while still having forward hook/pre-hook. ".
+            "If $self is a child of HybridBlock, the hooks will not take effect."
+        );
+    }
+    $self->SUPER::hybridize($self->_active, @args);
+}
+
+method cast(Dtype $dtype)
+{
+    $self->_clear_cached_op;
+    $self->SUPER::cast($dtype);
+}
+
+method _infer_attrs($infer_fn, $attr, @args)
+{
+    my ($inputs, $out) = $self->_get_graph(@args);
+    my ($args) = __PACKAGE__->_flatten([@args]);
+    my %in;
+    zip(sub {
+        my ($i, $j) = @_;
+        $in{ $i->name } = $j->$attr;
+    }, $inputs, $args);
+    my ($arg_attrs, $aux_attrs);
+    ($arg_attrs, undef, $aux_attrs) = $out->$infer_fn(%in);
+    if(not defined $arg_attrs)
+    {
+        confess("Cannot infer $attr of the parameters of ".$self->_class_name." from the given inputs");
+    }
+    my %sdict;
+    zip(sub {
+        my ($i, $j) = @_;
+        $sdict{ $i } = $j;
+    }, $out->list_arguments, $arg_attrs);
+    zip(sub {
+        my ($i, $j) = @_;
+        $sdict{ $i } = $j;
+    }, $out->list_auxiliary_states, $aux_attrs);
+
+    for my $i ($self->collect_params->values)
+    {
+        $i->$attr($sdict{ $i->name });
+    }
+}
+
+method infer_shape(@args)
+{
+    $self->_infer_attrs('infer_shape', 'shape', @args);
+}
+
+method infer_type(@args)
+{
+    $self->_infer_attrs('infer_type', 'dtype', @args);
 }
 
 method _get_graph(@args)
@@ -539,7 +872,15 @@ method _get_graph(@args)
         my ($in_format, $out_format);
         ($args, $in_format) = __PACKAGE__->_flatten($args);
         $self->_in_format($in_format);
-        my @inputs = map { AI::MXNet::Symbol->var("input_$_") } 0 .. @$args-1;
+        my @inputs;
+        if(@$args > 1)  # count flattened inputs: _regroup below needs one var per element
+        {
+            @inputs = map { AI::MXNet::Symbol->var("data_$_") } 0 .. $#$args;
+        }
+        else
+        {
+            @inputs = (AI::MXNet::Symbol->var("data"));
+        }
         my ($grouped_inputs) = __PACKAGE__->_regroup(\@inputs, $self->_in_format);
         my %params = map { $_ => $self->_reg_params->{$_}->var } keys %{ $self->_reg_params };
         my @out;
@@ -559,62 +900,75 @@ method _get_graph(@args)
         Infers shape of Parameters from inputs.
 =cut
 
-method infer_shape(@args)
+method _build_cache(@args)
 {
-    my ($inputs, $out) = $self->_get_graph(@args);
-    my $args = \@args;
-    ($args) = __PACKAGE__->_flatten($args);
-    my %in;
-    for(zip($inputs, $args)) {
-        my ($i, $j) = @$_;
-        $in{ $i->name } = $j->shape;
-    }
-    my ($arg_shapes, undef, $aux_shapes) = $out->infer_shape(%in);
-    my %sdict;
-    for(zip($out->list_arguments(), $arg_shapes)) {
-        my ($i, $j) = @$_;
-        $sdict{ $i } = $j;
-    }
-    my %aux;
-    for(zip($out->list_auxiliary_states(), $aux_shapes)) {
-        my ($i, $j) = @$_;
-        $aux{ $i } = $j;
-    }
-    %sdict = (%sdict, %aux);
-    for my $i ($self->collect_params->values)
+    my ($data, $out) = $self->_get_graph(@args);
+    my $i = 0;
+    my %data_names = map { $_->name => $i++ } @{ $data };
+    my $params = $self->collect_params;
+    my $input_names = $out->list_inputs;
+    my %param_names = map { $_ => 1 } $params->keys;
+    my %expected_names = map { $_ => 1 } @{ $input_names };
+    for my $name (keys %expected_names)
     {
-        $i->shape($sdict{ $i->name })
+        assert(
+            (exists $param_names{ $name } or exists $data_names{ $name }),
+            "Unknown input to HybridBlock: $name"
+        );
     }
-}
-
-method _build_cache(@args)
-{
-    my ($inputs, $out) = $self->_get_graph(@args);
-    $self->_cached_op(AI::MXNet::NDArray->CachedOp($out));
-    my %params = %{ $self->collect_params };
-    $self->_cached_params([map { $params{ $_ } } @{ $out->list_inputs }]);
-    assert(
-        (
-            ((keys %params) + (@{ $self->_cached_graph->[0] }))
-                ==
-            @{ $out->list_inputs }
-        ),
-        "Wrong number of inputs."
-    );
-    my %name2pos;
-    enumerate(sub {
-        my ($i, $var) = @_;
-        $name2pos{ $var->name } = $i;
-    }, $inputs);
-    my @in_idx;
+    my $unused = join(', ', map { "$data_names{$_}-th" } grep { !exists $expected_names{ $_ } } keys %data_names);
+    AI::MXNet::Logging->warning(
+        "The %s input to HybridBlock is not used by any ".
+        "computation. Is this intended?", $unused
+    ) if $unused;
+    $unused = join(', ', grep { !exists $expected_names{ $_ } } keys %param_names);
+    AI::MXNet::Logging->warning(
+        "Parameter %s is not used by any computation. " .
+        "Is this intended?", $unused
+    ) if $unused;
+
+    my @data_indices;
+    my @param_indices;
+    $self->_cached_op_args([]);
     enumerate(sub {
         my ($i, $name) = @_;
-        if(not exists $params{ $name })
+        if(exists $data_names{ $name })
+        {
+            push @data_indices, $i;
+            push @{ $self->_cached_op_args }, [1, $data_names{$name}];
+        }
+        else
         {
-            push @in_idx, [$i, $name2pos{ $name }];
+            push @param_indices, $i;
+            push @{ $self->_cached_op_args }, [0, $params->params->get($name)];
         }
-    }, $out->list_inputs);
-    $self->_in_idx(\@in_idx);
+    }, $input_names);
+    my %flags = (
+        data_indices  => \@data_indices,
+        param_indices => \@param_indices,
+        @{ $self->_flags }
+    );
+    $self->_cached_op(AI::MXNet::CachedOp->new($out, \%flags));
+}
+
+method _deferred_infer_shape(@args)
+{
+    eval {
+        $self->infer_shape(@args)
+    };
+    if($@)
+    {
+        confess(
+            "Deferred initialization failed because shape".
+            " cannot be inferred. $@"
+        );
+    }
+}
+
+method _clear_cached_op()
+{
+    $self->_cached_graph([]);
+    $self->_cached_op(undef);
 }
 
 use Data::Dumper;
@@ -624,32 +978,37 @@ method _call_cached_op(@args)
     {
         $self->_build_cache(@args);
     }
-
+    my $args = [@args];
+    my $fmt;
+    ($args, $fmt) = __PACKAGE__->_flatten($args);
+    assert((Dumper($fmt) eq Dumper($self->_in_format)), "Invalid input format");
     my @cargs;
     eval {
-        @cargs = map { defined($_) ? $_->data() : undef } @{ $self->_cached_params };
+        @cargs = map { (not $_->[0]) ? $_->[1]->data() : $args->[$_->[1]] } @{ $self->_cached_op_args };
     };
     if($@)
     {
         if($@ =~ /DeferredInitializationError/)
         {
-            $self->infer_shape(@args);
-            map { $_->_finish_deferred_init if defined } @{ $self->_cached_params };
-            @cargs = map { defined($_) ? $_->data() : undef } @{ $self->_cached_params };
+            $self->_deferred_infer_shape(@$args);
+            @cargs = ();
+            map {
+                if($_->[0])
+                {
+                    push @cargs, $args->[$_->[1]];
+                }
+                else
+                {
+                    $_->[1]->_finish_deferred_init();
+                    push @cargs, $_->[1]->data;
+                }
+            } @{ $self->_cached_op_args };
         }
         else
         {
             confess($@);
         }
     }
-    my $args = [@args];
-    my $fmt;
-    ($args, $fmt) = __PACKAGE__->_flatten($args);
-    assert((Dumper($fmt) eq Dumper($self->_in_format)), "Invalid input format");
-    for (@{ $self->_in_idx })
-    {
-        $cargs[$_->[0]] = $args->[$_->[1]];
-    }
     my $out = $self->_cached_op->(@cargs);
     if(blessed $out and $out->isa('AI::MXNet::NDArray'))
     {
@@ -704,8 +1063,8 @@ method forward($x, @args)
         {
             if($@ =~ /DeferredInitializationError/)
             {
-                $self->infer_shape($x, @args);
-                $_->_finish_deferred_init for $self->collect_params->values;
+                $self->_deferred_infer_shape($x, @args);
+                $_->_finish_deferred_init for $self->params->values;
                 %params = map { $_ => $self->_reg_params->{ $_ }->data($ctx) } keys %{ $self->_reg_params };
             }
             else
@@ -747,6 +1106,55 @@ method hybrid_forward($F, $x, @args)
     confess("NotImplementedError");
 }
 
+=head2 export
+
+        Export HybridBlock to json format that can be loaded by AI::MXNet::Module
+        or the C++ interface.
+
+        When there is only one input, it will be named 'data'. When there
+        are multiple inputs, they will be named `data_0`, `data_1`, etc.
+
+        Parameters
+        ----------
+        $path : str
+            Path to save model. Two files `path-symbol.json` and `path-xxxx.params`
+            will be created, where xxxx is the 4 digits epoch number.
+        :$epoch=0 : Int
+            Epoch number of saved model.
+=cut
+
+method export(Str $path, :$epoch=0)
+{
+    if(not @{ $self->_cached_graph })
+    {
+        confess(
+            "Please first call \$block->hybridize() and then run forward with ".
+            "this block at least once before calling export."
+        );
+    }
+    my $sym = $self->_cached_graph->[1];
+    $sym->save("$path-symbol.json");
+
+    my %arg_names = map { $_ => 1 } @{ $sym->list_arguments };
+    my %aux_names = map { $_ => 1 } @{ $sym->list_auxiliary_states };
+    my %arg_dict;
+    my $params = $self->collect_params;
+    for my $name ($params->keys)
+    {
+        my $param = $params->get($name);
+        if(exists $arg_names{ $name })
+        {
+            $arg_dict{ "arg:$name" } = $param->_reduce;
+        }
+        else
+        {
+            assert(exists $aux_names{ $name });
+            $arg_dict{ "aux:$name" } = $param->_reduce;
+        }
+    }
+    AI::MXNet::NDArray->save(sprintf('%s-%04d.params', $path, $epoch), \%arg_dict);
+}
+
 __PACKAGE__->register('AI::MXNet::Gluon');
 
 package AI::MXNet::Gluon::SymbolBlock;
@@ -799,15 +1207,16 @@ method python_constructor_arguments() { [qw/outputs inputs/] }
 sub BUILD
 {
     my ($self, $orig_params) = @_;
+    return unless defined $self->outputs and defined $self->inputs;
     $self->_prefix('');
     $self->_params(AI::MXNet::Gluon::ParameterDict->new(prefix => '', shared => $orig_params->{params}));
     if(blessed $self->inputs and @{ $self->inputs->list_outputs } == 1)
     {
         $self->inputs([$self->inputs]);
     }
-    if(blessed $self->outputs and @{ $self->outputs->list_outputs } == 1)
+    if(not blessed $self->outputs and @{ $self->outputs } == 1)
     {
-        $self->outputs([$self->outputs]);
+        $self->outputs($self->outputs->[0]);
     }
     my ($syms, $in_format) = __PACKAGE__->_flatten($self->inputs);
     my ($out, $out_format) = __PACKAGE__->_flatten($self->outputs);
@@ -825,6 +1234,20 @@ sub BUILD
         $input_names{ $i->name } = 1;
     }
 
+    # check if any symbol is row_sparse
+    my $row_sparse_storage = STORAGE_TYPE_STR_TO_ID->{row_sparse};
+    for my $i (@{ $out })
+    {
+        for my $j (@{ $i->get_internals })
+        {
+            assert(
+                (not defined $j->attr("__storage_type__") or $j->attr("__storage_type__") ne $row_sparse_storage),
+                "SymbolBlock doesn't support Parameter ${\ $j->name }  because its storage ".
+                "type is 'row_sparse'."
+            );
+        }
+    }
+
     for my $i (@{ $out->list_arguments })
     {
         if(not exists $input_names{$i})
@@ -842,7 +1265,33 @@ sub BUILD
     }
 
     $self->_cached_graph([$syms, $out]);
-    $self->_build_cache;
+    my $prefix = _common_prefix($self->_params->keys);
+    my %params = $self->_params->items;
+    while(my ($key, $val) = each %params)
+    {
+        $key =~ s/^\Q$prefix\E//;  # prefix is literal text, not a regex
+        $self->_reg_params->{ $key } = $val;
+    }
+    $self->_prefix($prefix);
+}
+
+func _common_prefix(@names)
+{
+    if(not @names)
+    {
+        return ''
+    }
+    my $prefix = $names[0];
+    for my $name (@names)
+    {
+        my $i = 0;
+        while($i < length($prefix) and $i < length($name) and substr($prefix, $i, 1) eq substr($name, $i, 1))
+        {
+            $i++;
+        }
+        $prefix = substr($prefix, 0, $i);
+    }
+    return $prefix;
 }
 
 method forward($x, @args)
@@ -894,11 +1343,53 @@ method forward($x, @args)
     }
 }
 
+method _clear_cached_op()
+{
+    my $tmp = $self->_cached_graph;
+    $self->SUPER::_clear_cached_op;
+    $self->_cached_graph($tmp);
+}
+
 method hybrid_forward(@args)
 {
     confess('NotImplementedError');
 }
 
+=head2 imports
+
+        Import model previously saved by HybridBlock->export or
+        Module->save_checkpoint as a SymbolBlock for use in Gluon.
+
+        Parameters
+        ----------
+        $symbol_file : Str
+            Path to symbol file.
+        $input_names : Str|ArrayRef[Str]
+            List of input variable names
+        :$param_file : Str, optional
+            Path to parameter file.
+        $ctx : Context, default undef
+            The context to initialize SymbolBlock on.
+
+        Returns
+        -------
+        SymbolBlock
+            SymbolBlock loaded from symbol and parameter files.
+=cut
+
+method imports(Str $symbol_file, Str|ArrayRef[Str] $input_names, Maybe [Str] $param_file=, Maybe[AI::MXNet::Context] $ctx=)
+{
+    my $sym = AI::MXNet::Symbol->load($symbol_file);
+    $input_names = [$input_names] unless ref $input_names;
+    my @inputs = map { AI::MXNet::Symbol->var($_) } @{ $input_names };
+    my $ret = __PACKAGE__->new($sym, \@inputs);
+    if(defined $param_file)
+    {
+        $ret->load_parameters($param_file, (defined $ctx ? (ctx=>$ctx) : ()));
+    }
+    return $ret
+}
+
 __PACKAGE__->register('AI::MXNet::Gluon');
 
 1;
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/NN.pm b/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/NN.pm
index 16b0415..673ee5d 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/NN.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/NN.pm
@@ -19,6 +19,7 @@ package AI::MXNet::Gluon::NN;
 use strict;
 use warnings;
 use AI::MXNet::Gluon::Block;
+use AI::MXNet::Gluon::NN::Activation;
 use AI::MXNet::Gluon::NN::BasicLayers;
 use AI::MXNet::Gluon::NN::ConvLayers;
 
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/NN/Activation.pm b/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/NN/Activation.pm
new file mode 100644
index 0000000..092893a
--- /dev/null
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/NN/Activation.pm
@@ -0,0 +1,249 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+package AI::MXNet::Gluon::NN::Activation;
+use strict;
+use warnings;
+use AI::MXNet::Function::Parameters;
+
+=head1 NAME
+
+    AI::MXNet::Gluon::NN::Activation
+=cut
+
+=head1 DESCRIPTION
+
+    Applies an activation function to input.
+
+    Parameters
+    ----------
+    activation : str
+        Name of activation function to use.
+        See mxnet.ndarray.Activation for available choices.
+
+    Input shape:
+        Arbitrary.
+
+    Output shape:
+        Same shape as input.
+=cut
+
+use AI::MXNet::Gluon::Mouse;
+extends 'AI::MXNet::Gluon::HybridBlock';
+has 'activation' => (is => 'ro', isa => 'Str', required => 1);
+
+# Names of the positional constructor arguments.
+method python_constructor_arguments() { [qw/activation/] }
+
+# The block is aliased by the name of its activation function.
+method _alias() { $self->activation }
+
+# Delegates to the Activation operator of the API object $F.
+method hybrid_forward(GluonClass $F, GluonInput $x)
+{
+    my $act_type = $self->activation;
+    return $F->Activation($x, act_type => $act_type, name => 'fwd');
+}
+
+# Stringifies as "<class name>(<activation>)".
+use overload '""' => sub {
+    my ($self) = @_;
+    return sprintf('%s(%s)', $self->_class_name, $self->activation);
+};
+
+__PACKAGE__->register('AI::MXNet::Gluon::NN');
+
+package AI::MXNet::Gluon::NN::LeakyReLU;
+=head1 NAME
+
+    AI::MXNet::Gluon::NN::LeakyReLU - Leaky version of a Rectified Linear Unit.
+=cut
+
+=head1 DESCRIPTION
+
+    Leaky version of a Rectified Linear Unit.
+
+    It allows a small gradient when the unit is not active
+
+    Parameters
+    ----------
+    alpha : float
+        slope coefficient for the negative half axis. Must be >= 0.
+=cut
+
+use AI::MXNet::Gluon::Mouse;
+extends 'AI::MXNet::Gluon::HybridBlock';
+has 'alpha' => (is => 'ro', isa => 'Num', required => 1);
+
+method python_constructor_arguments()
+{
+    ['alpha'];
+}
+
+# Validates the slope of the negative half axis. Zero is allowed (the layer then
+# behaves like a plain ReLU); only negative values are rejected, as documented
+# ("Must be >= 0") and as the error message states ("no less than 0").
+sub BUILD
+{
+    my $self = shift;
+    confess('Slope coefficient for LeakyReLU must be no less than 0')
+        unless $self->alpha >= 0;
+}
+
+method hybrid_forward(GluonClass $F, GluonInput $x)
+{
+    return $F->LeakyReLU($x, act_type => 'leaky', slope => $self->alpha, name=>'fwd');
+}
+
+use overload '""' => sub { my $self = shift; "${\ $self->_class_name }(${\ $self->alpha })"; };
+
+__PACKAGE__->register('AI::MXNet::Gluon::NN');
+
+package AI::MXNet::Gluon::NN::PReLU;
+=head1 NAME
+
+    AI::MXNet::Gluon::NN::PReLU - Parametric leaky version of a Rectified Linear Unit.
+=cut
+
+=head1 DESCRIPTION
+
+    Parametric leaky version of a Rectified Linear Unit.
+    https://arxiv.org/abs/1502.01852
+
+    It learns a gradient when the unit is not active
+
+    Parameters
+    ----------
+    alpha_initializer : Initializer
+        Initializer for the learnable slope parameter 'alpha' (shape [1]);
+        defaults to a constant 0.25.
+=cut
+
+use AI::MXNet::Gluon::Mouse;
+extends 'AI::MXNet::Gluon::HybridBlock';
+has 'alpha_initializer' => (is => 'ro', isa => 'Initializer', default => sub { AI::MXNet::Constant->new(0.25) });
+
+method python_constructor_arguments() { [qw/alpha_initializer/] }
+
+# Registers the single learnable slope 'alpha' (shape [1]) inside this
+# block's name scope.
+sub BUILD
+{
+    my ($self) = @_;
+    my $register_alpha = sub {
+        my $param = $self->params->get(
+            'alpha',
+            shape => [1],
+            init  => $self->alpha_initializer
+        );
+        $self->alpha($param);
+    };
+    $self->name_scope($register_alpha);
+}
+
+# The learned slope arrives as the named argument :$alpha and is handed to the
+# LeakyReLU operator as its gamma input in 'prelu' mode.
+method hybrid_forward(GluonClass $F, GluonInput $x, GluonInput :$alpha)
+{
+    my $slope = $alpha;
+    return $F->LeakyReLU($x, gamma => $slope, act_type => 'prelu',  name=>'fwd');
+}
+
+__PACKAGE__->register('AI::MXNet::Gluon::NN');
+
+package AI::MXNet::Gluon::NN::ELU;
+=head1 NAME
+
+    AI::MXNet::Gluon::NN::ELU - Exponential Linear Unit (ELU)
+=cut
+
+=head1 DESCRIPTION
+
+    Exponential Linear Unit (ELU)
+        "Fast and Accurate Deep Network Learning by Exponential Linear Units", Clevert et al, 2016
+        https://arxiv.org/abs/1511.07289
+        Published as a conference paper at ICLR 2016
+
+    Parameters
+    ----------
+    alpha : float
+        The alpha parameter as described by Clevert et al, 2016
+=cut
+
+use AI::MXNet::Gluon::Mouse;
+extends 'AI::MXNet::Gluon::HybridBlock';
+has 'alpha' => (is => 'ro', isa => 'Num', default => 1);
+
+method python_constructor_arguments() { [qw/alpha/] }
+
+# ELU(x) = x where x > 0, alpha * (exp(x) - 1) elsewhere, selected element-wise
+# with $F->where.
+method hybrid_forward(GluonClass $F, GluonInput $x)
+{
+    my $is_positive   = $x > 0;
+    my $negative_part = $self->alpha * ($F->exp($x) - 1);
+    return $F->where($is_positive, $x, $negative_part);
+}
+
+__PACKAGE__->register('AI::MXNet::Gluon::NN');
+
+package AI::MXNet::Gluon::NN::SELU;
+=head1 NAME
+
+    AI::MXNet::Gluon::NN::SELU - Scaled Exponential Linear Unit (SELU)
+=cut
+
+=head1 DESCRIPTION
+
+    Scaled Exponential Linear Unit (SELU)
+    "Self-Normalizing Neural Networks", Klambauer et al, 2017
+    https://arxiv.org/abs/1706.02515
+=cut
+
+use AI::MXNet::Gluon::Mouse;
+extends 'AI::MXNet::Gluon::HybridBlock';
+
+# Sets the two fixed SELU constants (Klambauer et al., 2017) once at construction.
+sub BUILD
+{
+    my ($self) = @_;
+    my %constant = (
+        scale => 1.0507009873554804934193349852946,
+        alpha => 1.6732632423543772848170429916717,
+    );
+    $self->scale($constant{scale});
+    $self->alpha($constant{alpha});
+}
+
+# SELU(x) = scale * (x where x > 0, alpha * (exp(x) - 1) elsewhere).
+method hybrid_forward(GluonClass $F, GluonInput $x)
+{
+    my $factor      = $self->scale;
+    my $is_positive = $x > 0;
+    my $elu         = $F->where($is_positive, $x, $self->alpha * ($F->exp($x) - 1));
+    return $factor * $elu;
+}
+
+__PACKAGE__->register('AI::MXNet::Gluon::NN');
+
+package AI::MXNet::Gluon::NN::Swish;
+=head1 NAME
+
+    AI::MXNet::Gluon::NN::Swish - Swish Activation function
+=cut
+
+=head1 DESCRIPTION
+
+    Swish Activation function
+        https://arxiv.org/pdf/1710.05941.pdf
+
+    Parameters
+    ----------
+    beta : float
+        swish(x) = x * sigmoid(beta*x)
+=cut
+
+use AI::MXNet::Gluon::Mouse;
+extends 'AI::MXNet::Gluon::HybridBlock';
+has 'beta' => (is => 'ro', isa => 'Num', default => 1);
+
+method python_constructor_arguments()
+{
+    ['beta'];
+}
+
+# swish(x) = x * sigmoid(beta * x)
+method hybrid_forward(GluonClass $F, GluonInput $x)
+{
+    return $x * $F->sigmoid($self->beta * $x, name=>'fwd');
+}
+
+__PACKAGE__->register('AI::MXNet::Gluon::NN');
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/NN/BasicLayers.pm b/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/NN/BasicLayers.pm
index 6ef4714..9541790 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/NN/BasicLayers.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/NN/BasicLayers.pm
@@ -47,15 +47,15 @@ extends 'AI::MXNet::Gluon::Block';
     Adds one or more blocks on top of the stack, in argument order.
 =cut
 
-method add(AI::MXNet::Gluon::Block $block)
+method add(AI::MXNet::Gluon::Block @block)
 {
+    # Register each given block as a child, in argument order.
-    $self->register_child($block);
+    $self->register_child($_) for @block;
 }
 
 
 method forward($x)
 {
-    for my $block (@{ $self->_children })
+    for my $block ($self->_children->values)
     {
         $x = $block->($x);
     }
@@ -66,17 +66,24 @@ use overload
     '""' => sub
     {
         my $self = shift;
-        my $s = "%s(\n{%s}\n)";
+        my $s = "%s(\n%s\n)";
         my @blocks;
         my $k = 0;
-        for my $v (@{ $self->{_children} })
+        for my $v ($self->_children->values)
         {
             push @blocks, "  ($k): ".AI::MXNet::Base::_indent("$v", 2);
             $k++;
         }
-        sprintf("%s(\n{%s}\n)", $self->_class_name, join("\n", @blocks));
+        sprintf("%s(\n%s\n)", $self->_class_name, join("\n", @blocks));
     },
-    '@{}' => sub { shift->_children };
+    '@{}' => sub { [shift->_children->values] };
+
+# Builds a new container holding the children at the indices listed in $slice,
+# looked up through this object's array-dereference overload.
+method slice(Slice $slice)
+{
+    my $container = __PACKAGE__->new;
+    my @selected  = @{ $self }[ @{ $slice } ];
+    $container->add(@selected);
+    return $container;
+}
 
 __PACKAGE__->register('AI::MXNet::Gluon::NN');
 
@@ -109,15 +116,15 @@ extends 'AI::MXNet::Gluon::HybridBlock';
     Adds one or more blocks on top of the stack, in argument order.
 =cut
 
-method add(AI::MXNet::Gluon::HybridBlock $block)
+method add(AI::MXNet::Gluon::HybridBlock @block)
 {
-    $self->register_child($block);
+    $self->register_child($_) for @block;
 }
 
 
-method forward($x)
+method hybrid_forward($F, $x)
 {
-    for my $block (@{ $self->_children })
+    for my $block ($self->_children->values)
     {
         $x = $block->($x);
     }
@@ -128,17 +135,24 @@ use overload
     '""' => sub
     {
         my $self = shift;
-        my $s = "%s(\n{%s}\n)";
+        my $s = "%s(\n%s\n)";
         my @blocks;
         my $k = 0;
-        for my $v (@{ $self->{_children} })
+        for my $v ($self->_children->values)
         {
             push @blocks, "  ($k): ".AI::MXNet::Base::_indent("$v", 2);
             $k++;
         }
-        sprintf("%s(\n{%s}\n)", $self->_class_name, join("\n", @blocks));
+        sprintf("%s(\n%s\n)", $self->_class_name, join("\n", @blocks));
     },
-    '@{}' => sub { shift->_children };
+    '@{}' => sub { [shift->_children->values] };
+
+# Builds a new hybrid container holding the children at the indices listed in
+# $slice, looked up through this object's array-dereference overload.
+method slice(Slice $slice)
+{
+    my $container = __PACKAGE__->new;
+    my @selected  = @{ $self }[ @{ $slice } ];
+    $container->add(@selected);
+    return $container;
+}
 
 __PACKAGE__->register('AI::MXNet::Gluon::NN');
 
@@ -293,52 +307,6 @@ use overload '""' => sub {
 
 __PACKAGE__->register('AI::MXNet::Gluon::NN');
 
-package AI::MXNet::Gluon::NN::Activation;
-
-=head1 
-
-    AI::MXNet::Gluon::NN::Activation
-=cut
-
-=head1 DESCRIPTION
-
-    Applies an activation function to input.
-
-    Parameters
-    ----------
-    activation : str
-        Name of activation function to use.
-        See mxnet.ndarray.Activation for available choices.
-
-    Input shape:
-        Arbitrary.
-
-    Output shape:
-        Same shape as input.
-=cut
-use AI::MXNet::Gluon::Mouse;
-extends 'AI::MXNet::Gluon::HybridBlock';
-has 'activation' => (is => 'ro', isa => 'Str', required => 1);
-
-method python_constructor_arguments()
-{
-    ['activation'];
-}
-
-method _alias()
-{
-    return $self->activation;
-}
-
-method hybrid_forward(GluonClass $F, GluonInput $x)
-{
-    return $F->Activation($x, act_type => $self->activation, name=>'fwd');
-}
-
-use overload '""' => sub { my $self = shift; "${\ $self->_class_name }(${\ $self->activation })"; };
-
-__PACKAGE__->register('AI::MXNet::Gluon::NN');
-
 package AI::MXNet::Gluon::NN::Dropout;
 use AI::MXNet::Gluon::Mouse;
 extends 'AI::MXNet::Gluon::HybridBlock';
@@ -517,51 +485,6 @@ use overload '""' => sub {
 
 __PACKAGE__->register('AI::MXNet::Gluon::NN');
 
-package AI::MXNet::Gluon::NN::LeakyReLU;
-use AI::MXNet::Gluon::Mouse;
-extends 'AI::MXNet::Gluon::HybridBlock';
-
-=head1 NAME
-
-    AI::MXNet::Gluon::NN::LeakyReLU
-=cut
-
-=head1 DESCRIPTION
-
-    Leaky version of a Rectified Linear Unit.
-
-    It allows a small gradient when the unit is not active
-
-        `f(x) = alpha * x for x < 0`,
-        `f(x) = x for x >= 0`.
-
-    Parameters
-    ----------
-    alpha : float
-        slope coefficient for the negative half axis. Must be >= 0.
-
-
-    Input shape:
-        Arbitrary.
-
-    Output shape:
-        Same shape as input.
-=cut
-has 'alpha' => (is => 'ro', isa => 'Num', required => 1);
-method python_constructor_arguments()
-{
-    ['alpha'];
-}
-
-method hybrid_forward(GluonClass $F, GluonInput $x)
-{
-    return $F->LeakyReLU($x, act_type => 'leaky', slope => $self->alpha, name => 'fwd');
-}
-
-use overload '""' => sub { my $self = shift; "${\ $self->_class_name }(${\ $self->alpha })"; };
-
-__PACKAGE__->register('AI::MXNet::Gluon::NN');
-
 package AI::MXNet::Gluon::NN::Embedding;
 use AI::MXNet::Gluon::Mouse;
 extends 'AI::MXNet::Gluon::HybridBlock';
@@ -587,19 +510,15 @@ extends 'AI::MXNet::Gluon::HybridBlock';
         Data type of output embeddings.
     weight_initializer : Initializer
         Initializer for the `embeddings` matrix.
-
-
-    Input shape:
-        2D tensor with shape: `(N, M)`.
-
-    Output shape:
-        3D tensor with shape: `(N, M, output_dim)`.
+    sparse_grad: bool
+        If True, gradient w.r.t. weight will be a 'row_sparse' NDArray.
 =cut
 
 has [qw/input_dim
     output_dim/]         => (is => 'ro', isa => 'DimSize', required => 1);
 has 'dtype'              => (is => 'ro', isa => 'Dtype', default => 'float32');
-has 'weight_initalizer'  => (is => 'ro', isa => 'Maybe[Initializer]');
+# Must match the accessor BUILD reads ($self->weight_initializer); with the old
+# misspelling the weight_initializer constructor argument never reached it.
+has 'weight_initializer' => (is => 'ro', isa => 'Maybe[Initializer]');
+has 'sparse_grad'        => (is => 'ro', isa => 'Bool', default => 0);
 has [qw/_kwargs weight/] => (is => 'rw', init_arg => undef);
 method python_constructor_arguments()
 {
@@ -612,14 +531,17 @@ sub BUILD
     $self->_kwargs({
         input_dim => $self->input_dim,
         output_dim =>  $self->output_dim,
-        dtype => $self->dtype
+        dtype => $self->dtype,
+        sparse_grad => $self->sparse_grad
     });
     $self->weight(
         $self->params->get(
             'weight',
             shape => [$self->input_dim, $self->output_dim],
             init => $self->weight_initializer,
-            allow_deferred_init => 1
+            allow_deferred_init => 1,
+            dtype => $self->dtype,
+            grad_stype => ($self->sparse_grad ? 'row_sparse' : 'default')
         )
     );
 }
@@ -665,4 +587,337 @@ use overload '""' => sub { shift->_class_name };
 
 __PACKAGE__->register('AI::MXNet::Gluon::NN');
 
+package AI::MXNet::Gluon::NN::InstanceNorm;
+use AI::MXNet::Gluon::Mouse;
+extends 'AI::MXNet::Gluon::HybridBlock';
+
+=head1 NAME
+
+    AI::MXNet::Gluon::NN::InstanceNorm - Applies instance normalization to the n-dimensional input array.
+=cut
+
+=head1 DESCRIPTION
+
+    Applies instance normalization to the n-dimensional input array.
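+    This operator takes an n-dimensional input array where (n>2) and normalizes
+    the input using the following formula:
+
+        out = (x - mean[x]) / sqrt(Var[x] + epsilon) * gamma + beta
+
+    where the mean and variance are taken over all axes except the first axis
+    and the channel axis given by 'axis'.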
+
+    Parameters
+    ----------
+    axis : int, default 1
+        The axis that will be excluded in the normalization process. This is typically the channels
+        (C) axis. For instance, after a `Conv2D` layer with `layout='NCHW'`,
+        set `axis=1` in `InstanceNorm`. If `layout='NHWC'`, then set `axis=3`. Data will be
+        normalized along axes excluding the first axis and the axis given.
+    epsilon: float, default 1e-5
+        Small float added to variance to avoid dividing by zero.
+    center: bool, default True
+        If True, add offset of `beta` to normalized tensor.
+        If False, `beta` is ignored.
+    scale: bool, default False
+        If True, multiply by `gamma`. If False, `gamma` is not used.
+        When the next layer is linear (also e.g. `nn.relu`),
+        this can be disabled since the scaling
+        will be done by the next layer.
+    beta_initializer: str or `Initializer`, default 'zeros'
+        Initializer for the beta weight.
+    gamma_initializer: str or `Initializer`, default 'ones'
+        Initializer for the gamma weight.
+    in_channels : int, default 0
+        Number of channels (feature maps) in input data. If not specified,
+        initialization will be deferred to the first time `forward` is called
+        and `in_channels` will be inferred from the shape of input data.
+
+    References
+    ----------
+        Instance Normalization: The Missing Ingredient for Fast Stylization
+        <https://arxiv.org/abs/1607.08022>
+
+    Examples
+    --------
+    >>> # Input of shape (2,1,2)
+    >>> $x = mx->nd->array([[[ 1.1,  2.2]],
+    ...                 [[ 3.3,  4.4]]]);
+    >>> $layer = nn->InstanceNorm()
+    >>> $layer->initialize(ctx=>mx->cpu(0))
+    >>> $layer->($x)
+    [[[-0.99998355  0.99998331]]
+     [[-0.99998319  0.99998361]]]
+    <NDArray 2x1x2 @cpu(0)>
+=cut
+
+has 'axis'              => (is => 'ro', isa => 'Int',  default => 1);
+has 'epsilon'           => (is => 'ro', isa => 'Num',  default => 1e-5);
+has 'center'            => (is => 'ro', isa => 'Bool', default => 1);
+has 'scale'             => (is => 'ro', isa => 'Bool', default => 0);
+has 'beta_initializer'  => (is => 'rw', isa => 'Initializer', default => 'zeros');
+has 'gamma_initializer' => (is => 'rw', isa => 'Initializer', default => 'ones');
+has 'in_channels'       => (is => 'rw', isa => 'Int',  default => 0);
+has [qw/_kwargs
+        gamma beta/]    => (is => 'rw', init_arg => undef);
+method python_constructor_arguments()
+{
+    [qw/axis epsilon center scale beta_initializer gamma_initializer in_channels/];
+}
+
+# Declares the gamma (scale) and beta (offset) parameters, both of shape
+# [in_channels] and deferred-initialized so in_channels may be inferred later.
+# Each one is only trained when its own flag is on: gamma follows 'scale',
+# beta follows 'center' (previously beta wrongly followed 'scale').
+sub BUILD
+{
+    my $self = shift;
+    $self->_kwargs(Hash::Ordered->new(eps => $self->epsilon, axis => $self->axis, center => $self->center, scale => $self->scale));
+    $self->gamma(
+        $self->params->get(
+            'gamma', grad_req => $self->scale ? 'write' :'null',
+            shape => [$self->in_channels], init => $self->gamma_initializer,
+            allow_deferred_init => 1
+        )
+    );
+    $self->beta(
+        $self->params->get(
+            'beta', grad_req => $self->center ? 'write' :'null',
+            shape => [$self->in_channels], init => $self->beta_initializer,
+            allow_deferred_init => 1
+        )
+    );
+}
+
+# Applies the InstanceNorm operator with the channel axis in position 1. For any
+# other 'axis', the input is swapped into that position before the call and the
+# result is swapped back afterwards.
+method hybrid_forward(GluonClass $F, GluonInput $x, GluonInput :$gamma, GluonInput :$beta)
+{
+    my $axis       = $self->axis;
+    my $needs_swap = $axis != 1;
+    my $input      = $needs_swap ? $x->swapaxes(1, $axis) : $x;
+    my $out = $F->InstanceNorm(
+        $input, $gamma, $beta,
+        name => 'fwd', eps => $self->epsilon
+    );
+    return $needs_swap ? $out->swapaxes(1, $axis) : $out;
+}
+
+# Stringifies as "<class name>(eps=..., axis=..., center=..., scale=..., in_channels=...)".
+use overload '""' => sub {
+    my $self = shift;
+    # $in_channels already carries its leading ", " separator, so it is appended
+    # directly (joining it with another ", " produced ", , in_channels=").
+    my $in_channels = ", in_channels=${\ $self->in_channels }";
+    my $content = join(', ', map { join('=', $_, $self->_kwargs->get($_)) } $self->_kwargs->keys);
+    return "${\ $self->_class_name }($content$in_channels)";
+};
+
+__PACKAGE__->register('AI::MXNet::Gluon::NN');
+
+package AI::MXNet::Gluon::NN::LayerNorm;
+use AI::MXNet::Gluon::Mouse;
+extends 'AI::MXNet::Gluon::HybridBlock';
+
+=head1 NAME
+
+    AI::MXNet::Gluon::NN::LayerNorm - Applies layer normalization to the n-dimensional input array.
+=cut
+
+=head1 DESCRIPTION
+
+    Applies layer normalization to the n-dimensional input array.
+    This operator takes an n-dimensional input array and normalizes
+    the input along the given axis:
+
+        out = (x - mean[x, axis]) / sqrt(Var[x, axis] + epsilon) * gamma + beta
+
+    Parameters
+    ----------
+    axis : int, default -1
+        The axis that should be normalized. This is typically the axis of the channels.
+    epsilon: float, default 1e-5
+        Small float added to variance to avoid dividing by zero.
+    center: bool, default True
+        If True, add offset of `beta` to normalized tensor.
+        If False, `beta` is ignored.
+    scale: bool, default True
+        If True, multiply by `gamma`. If False, `gamma` is not used.
+    beta_initializer: str or `Initializer`, default 'zeros'
+        Initializer for the beta weight.
+    gamma_initializer: str or `Initializer`, default 'ones'
+        Initializer for the gamma weight.
+    in_channels : int, default 0
+        Number of channels (feature maps) in input data. If not specified,
+        initialization will be deferred to the first time `forward` is called
+        and `in_channels` will be inferred from the shape of input data.
+
+    References
+    ----------
+        Layer Normalization
+        <https://arxiv.org/pdf/1607.06450.pdf>
+
+    Examples
+    --------
+    >>> # Input of shape (2, 5)
+    >>> $x = mx->nd->array([[1, 2, 3, 4, 5], [1, 1, 2, 2, 2]])
+    >>> # Layer normalization is calculated with the above formula
+    >>> $layer = nn->LayerNorm()
+    >>> $layer->initialize(ctx=>mx->cpu(0))
+    >>> $layer->($x)
+    [[-1.41421    -0.707105    0.          0.707105    1.41421   ]
+     [-1.2247195  -1.2247195   0.81647956  0.81647956  0.81647956]]
+    <NDArray 2x5 @cpu(0)>
+=cut
+
+has 'axis'              => (is => 'ro', isa => 'Int',  default => -1);
+has 'epsilon'           => (is => 'ro', isa => 'Num',  default => 1e-5);
+has 'center'            => (is => 'ro', isa => 'Bool', default => 1);
+# Defaults to on, as documented ("scale: bool, default True").
+has 'scale'             => (is => 'ro', isa => 'Bool', default => 1);
+has 'beta_initializer'  => (is => 'rw', isa => 'Initializer', default => 'zeros');
+has 'gamma_initializer' => (is => 'rw', isa => 'Initializer', default => 'ones');
+has 'in_channels'       => (is => 'rw', isa => 'Int',  default => 0);
+has [qw/_kwargs
+        gamma beta/]    => (is => 'rw', init_arg => undef);
+method python_constructor_arguments()
+{
+    [qw/axis epsilon center scale beta_initializer gamma_initializer in_channels/];
+}
+
+# Declares the gamma (scale) and beta (offset) parameters, both of shape
+# [in_channels] and deferred-initialized so in_channels may be inferred later.
+# Each one is only trained when its own flag is on: gamma follows 'scale',
+# beta follows 'center' (previously beta wrongly followed 'scale').
+sub BUILD
+{
+    my $self = shift;
+    $self->_kwargs(Hash::Ordered->new(eps => $self->epsilon, axis => $self->axis, center => $self->center, scale => $self->scale));
+    $self->gamma(
+        $self->params->get(
+            'gamma', grad_req => $self->scale ? 'write' :'null',
+            shape => [$self->in_channels], init => $self->gamma_initializer,
+            allow_deferred_init => 1
+        )
+    );
+    $self->beta(
+        $self->params->get(
+            'beta', grad_req => $self->center ? 'write' :'null',
+            shape => [$self->in_channels], init => $self->beta_initializer,
+            allow_deferred_init => 1
+        )
+    );
+}
+
+# Normalizes $x along 'axis' with the LayerNorm operator, using the gamma/beta
+# parameters that arrive as named arguments.
+method hybrid_forward(GluonClass $F, GluonInput $x, GluonInput :$gamma, GluonInput :$beta)
+{
+    my @options = (eps => $self->epsilon, axis => $self->axis);
+    return $F->LayerNorm($x, $gamma, $beta, @options);
+}
+
+# Stringifies as "<class name>(eps=..., axis=..., center=..., scale=..., in_channels=...)".
+use overload '""' => sub {
+    my $self = shift;
+    # $in_channels already carries its leading ", " separator, so it is appended
+    # directly (joining it with another ", " produced ", , in_channels=").
+    my $in_channels = ", in_channels=${\ $self->in_channels }";
+    my $content = join(', ', map { join('=', $_, $self->_kwargs->get($_)) } $self->_kwargs->keys);
+    return "${\ $self->_class_name }($content$in_channels)";
+};
+
+__PACKAGE__->register('AI::MXNet::Gluon::NN');
+
+package AI::MXNet::Gluon::NN::Lambda;
+use AI::MXNet::Gluon::Mouse;
+extends 'AI::MXNet::Gluon::Block';
+
+=head1 NAME
+
+    AI::MXNet::Gluon::NN::Lambda - Wraps an operator or an expression as a Block object.
+=cut
+
+=head1 DESCRIPTION
+
+    Wraps an operator or an expression as a Block object.
+
+    Parameters
+    ----------
+    function : str or sub
+        Function used in lambda must be one of the following:
+        1) the name of an operator that is available in ndarray. For example
+
+            $block = nn->Lambda('tanh')
+
+        2) a sub. For example
+
+            $block = nn->Lambda(sub { my $x = shift; nd->LeakyReLU($x, slope=>0.1) });
+=cut
+
+has '_func_impl' => (is => 'rw', init_arg => 'function', isa => 'Str|CodeRef', required => 1);
+has '_func_name' => (is => 'rw', init_arg => undef, default => 'custom_sub');
+method python_constructor_arguments() { ['function'] }
+
+# A code ref is used as-is. A string must name an AI::MXNet::NDArray method;
+# the name is kept for stringification and wrapped into a code ref.
+sub BUILD
+{
+    my ($self) = @_;
+    return if ref $self->_func_impl;
+    confess("Function name ${\ $self->_func_impl } is not found in ndarray.")
+        unless AI::MXNet::NDArray->can($self->_func_impl);
+    my $op_name = $self->_func_impl;
+    $self->_func_name($op_name);
+    $self->_func_impl(sub { AI::MXNet::NDArray->$op_name(@_) });
+}
+
+# Passes every forward() argument to the wrapped function.
+method forward(@args)
+{
+    my $impl = $self->_func_impl;
+    return $impl->(@args);
+}
+
+# Stringifies as "<class name>(<function name or 'custom_sub'>)".
+use overload '""' => sub {
+    my ($self) = @_;
+    return sprintf('%s(%s)', $self->_class_name, $self->_func_name);
+};
+
+__PACKAGE__->register('AI::MXNet::Gluon::NN');
+
+package AI::MXNet::Gluon::NN::HybridLambda;
+use AI::MXNet::Gluon::Mouse;
+extends 'AI::MXNet::Gluon::HybridBlock';
+
+=head1 NAME
+
+    AI::MXNet::Gluon::NN::HybridLambda - Wraps an operator or an expression as a HybridBlock object.
+=cut
+
+=head1 DESCRIPTION
+
+    Wraps an operator or an expression as a HybridBlock object.
+
+    Parameters
+    ----------
+    function : str or sub
+        Function used in lambda must be one of the following:
+        1) the name of an operator that is available in symbol and ndarray. For example
+
+            $block = nn->HybridLambda('tanh')
+
+        2) a sub. For example
+
+            $block = nn->HybridLambda(sub { my ($F, $x) = @_; $F->LeakyReLU($x, slope=>0.1) });
+=cut
+
+has '_func_impl' => (is => 'rw', init_arg => 'function', isa => 'Str|CodeRef', required => 1);
+has '_func_name' => (is => 'rw', init_arg => undef, default => 'custom_sub');
+method python_constructor_arguments() { ['function'] }
+
+# A code ref is used as-is. A string must name an operator available in BOTH
+# AI::MXNet::NDArray and AI::MXNet::Symbol, because the wrapper dispatches on
+# whichever API object $F it receives. (Previously either one sufficed, which let
+# an NDArray-only name through and fail later in symbolic mode.) The name is kept
+# for stringification and wrapped into a code ref that calls $F->name(...).
+sub BUILD
+{
+    my $self = shift;
+    if(not ref $self->_func_impl)
+    {
+        confess("Function name ${\ $self->_func_impl } is not found in symbol/ndarray.")
+            unless AI::MXNet::NDArray->can($self->_func_impl) and AI::MXNet::Symbol->can($self->_func_impl);
+        $self->_func_name($self->_func_impl);
+        my $f = $self->_func_impl;
+        $self->_func_impl(sub { my $F = shift; return $F->$f(@_) });
+    }
+}
+
+# Receives ($F, @inputs) and forwards them to the wrapped function unchanged.
+method hybrid_forward(@args)
+{
+    return $self->_func_impl->(@args);
+}
+
+use overload '""' => sub {
+    my $self = shift;
+    return "${\ $self->_class_name }(${\ $self->_func_name })";
+};
+
+__PACKAGE__->register('AI::MXNet::Gluon::NN');
+
 1;
\ No newline at end of file
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/NN/ConvLayers.pm b/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/NN/ConvLayers.pm
index 502f522..a4bb89b 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/NN/ConvLayers.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/NN/ConvLayers.pm
@@ -835,6 +835,7 @@ has 'global_pool' => (is => 'rw', isa => 'Bool', default => 0);
 has 'kwargs'      => (is => 'rw', init_arg => undef);
 has 'pool_type'   => (is => 'rw', isa => 'PoolType');
 has 'layout'      => (is => 'rw');
+# When left undefined, count_include_pad is omitted from the operator kwargs
+# built in BUILD, so (presumably) the Pooling operator's own default applies.
+has 'count_include_pad' => (is => 'rw', isa => 'Bool');
 method python_constructor_arguments() { [qw/pool_size strides padding/] }
 
 sub BUILD
@@ -856,7 +857,8 @@ sub BUILD
     $self->kwargs({
         kernel => $self->pool_size, stride => $self->strides, pad => $self->padding,
         global_pool => $self->global_pool, pool_type => $self->pool_type,
-        pooling_convention => $self->ceil_mode ? 'full' : 'valid'
+        pooling_convention => $self->ceil_mode ? 'full' : 'valid',
+        (defined $self->count_include_pad ? (count_include_pad => $self->count_include_pad) : ())
     });
 }
 
@@ -1116,6 +1118,8 @@ extends 'AI::MXNet::Gluon::NN::MaxPool1D';
         respectively. padding is applied on 'W' dimension.
     ceil_mode : bool, default False
         When `True`, will use ceil instead of floor to compute the output shape.
+    count_include_pad : bool, default True
+        When 'False', will exclude padding elements when computing the average value.
 
 
     Input shape:
@@ -1135,6 +1139,7 @@ extends 'AI::MXNet::Gluon::NN::MaxPool1D';
 =cut
 
 has '+pool_type' => (default => 'avg');
+# Padded elements are counted in the average unless count_include_pad => 0 is passed.
+has '+count_include_pad' => (default => 1);
 
 __PACKAGE__->register('AI::MXNet::Gluon::NN');
 
@@ -1167,6 +1172,8 @@ extends 'AI::MXNet::Gluon::NN::MaxPool2D';
         dimensions respectively. padding is applied on 'H' and 'W' dimension.
     ceil_mode : bool, default False
         When True, will use ceil instead of floor to compute the output shape.
+    count_include_pad : bool, default True
+        When 'False', will exclude padding elements when computing the average value.
 
 
     Input shape:
@@ -1187,6 +1194,7 @@ extends 'AI::MXNet::Gluon::NN::MaxPool2D';
 =cut
 
 has '+pool_type' => (default => 'avg');
+# Padded elements are counted in the average unless count_include_pad => 0 is passed.
+has '+count_include_pad' => (default => 1);
 
 __PACKAGE__->register('AI::MXNet::Gluon::NN');
 
@@ -1220,6 +1228,8 @@ extends 'AI::MXNet::Gluon::NN::MaxPool3D';
         dimension.
     ceil_mode : bool, default False
         When True, will use ceil instead of floor to compute the output shape.
+    count_include_pad : bool, default True
+        When 'False', will exclude padding elements when computing the average value.
 
 
     Input shape:
@@ -1242,6 +1252,8 @@ extends 'AI::MXNet::Gluon::NN::MaxPool3D';
 =cut
 
 has '+pool_type' => (default => 'avg');
+# Padded elements are counted in the average unless count_include_pad => 0 is passed.
+has '+count_include_pad' => (default => 1);
+
 __PACKAGE__->register('AI::MXNet::Gluon::NN');
 
 package AI::MXNet::Gluon::NN::GlobalMaxPool1D;
@@ -1362,4 +1374,45 @@ has '+ceil_mode'   => (default => 1);
 
 __PACKAGE__->register('AI::MXNet::Gluon::NN');
 
+package AI::MXNet::Gluon::NN::ReflectionPad2D;
+use AI::MXNet::Gluon::Mouse;
+extends 'AI::MXNet::Gluon::HybridBlock';
+
+=head1 NAME
+
+    AI::MXNet::Gluon::NN::ReflectionPad2D
+=cut
+
+=head1 DESCRIPTION
+
+    Pads the input tensor using the reflection of the input boundary.
+
+    Parameters
+    ----------
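+    padding: int or array ref of 8 ints, default 0
+        A single integer N is expanded to [0, 0, 0, 0, N, N, N, N]; an array
+        ref of exactly 8 integers is passed through as the pad operator's pad_width.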
+
+    Examples
+    --------
+    >>> $m = nn->ReflectionPad2D(3);
+    >>> $input = mx->nd->random->normal(shape=>[16, 3, 224, 224]);
+    >>> $output = $m->($input);
+=cut
+
+has 'padding' => (is => 'rw', isa => 'Int|ArrayRef[Int]', default => 0);
+method python_constructor_arguments() { ['padding'] }
+
+# Normalizes 'padding' to the 8-element pad_width handed to $F->pad: a single
+# integer N becomes [0, 0, 0, 0, N, N, N, N] (presumably before/after pairs for
+# the four axes of an NCHW input, padding only the last two). An explicit array
+# ref must already contain exactly 8 integers.
+sub BUILD
+{
+    my $self = shift;
+    $self->padding([(0)x4, ($self->padding)x4]) if not ref $self->padding;
+    confess("padding must be 8 integers long") unless @{ $self->padding } == 8;
+}
+
+method hybrid_forward(GluonClass $F, GluonInput $x)
+{
+    return $F->pad($x, mode=>'reflect', pad_width=>$self->padding);
+}
+
+__PACKAGE__->register('AI::MXNet::Gluon::NN');
+
 1;
\ No newline at end of file
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/Parameter.pm b/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/Parameter.pm
index 131a6ab..c39d5d4 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/Parameter.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/Parameter.pm
@@ -53,7 +53,7 @@ use AI::MXNet::Function::Parameters;
           iteration when using this option.
         - 'null' means gradient is not requested for this parameter. gradient arrays
           will not be allocated.
-    shape : array ref of int, default None
+    shape : array ref of int or int, default undef
         Shape of this parameter. By default shape is not specified. Parameter with
         unknown shape can be used for `Symbol` API, but `init` will throw an error
         when using `NDArray` API.
@@ -66,6 +66,11 @@ use AI::MXNet::Function::Parameters;
         Weight decay multiplier (L2 regularizer coefficient). Works similar to lr_mult.
     init : Initializer, default None
         Initializer of this parameter. Will use the global initializer by default.
+    stype: {'default', 'row_sparse', 'csr'}, defaults to 'default'.
+        The storage type of the parameter.
+    grad_stype: {'default', 'row_sparse', 'csr'}, defaults to 'default'.
+        The storage type of the parameter's gradient.
+
 
     Attributes
     ----------
@@ -80,8 +85,9 @@ use AI::MXNet::Base;
+# Stringifies as "Parameter <name> (shape=(d1, d2, ...), dtype=<dtype>, stype=<stype>)";
+# an undefined shape prints as "()".
 use overload '""' => sub {
         my $self = shift;
         "Parameter " . $self->name.
-        " (shape=(" . join(',', @{ $self->shape//[] }) .")".
-        ", dtype=" . $self->dtype.")"
+        " (shape=(" . join(', ', @{ $self->shape//[] }) .")".
+        ", dtype=" . $self->dtype.
+        ", stype=" . $self->stype.")"
     },
     fallback => 1;
 
@@ -103,19 +109,22 @@ sub BUILD
 {
     my $self = shift;
     $self->grad_req($self->_grad_req);
+    $self->_shape([$self->_shape]) if defined $self->_shape and not ref $self->_shape;
     $self->_deferred_init([]);
 }
 
 has 'name'                => (is => 'ro', isa => 'Str', required => 1);
 has '_grad_req'           => (is => 'rw', isa => 'GradReq', init_arg => 'grad_req', default => 'write');
-has 'shape'               => (is => 'rw', isa => 'Shape');
+# Raw storage for the 'shape' constructor argument (a bare Int is wrapped into an
+# array ref by BUILD); reads and validated writes go through the shape() method.
+has '_shape'              => (is => 'rw', isa => 'Maybe[Shape|Int]', init_arg => 'shape');
 has 'dtype'               => (is => 'rw', isa => 'Dtype', default => 'float32');
+# Storage types of the parameter data and of its gradient.
+has ['stype',
+     'grad_stype']        => (is => 'rw', isa => 'Stype', default => 'default');
 has [qw/lr_mult wd_mult/] => (is => 'rw', isa => 'Num', default => 1);
 has 'init'                => (is => 'rw', isa => 'Maybe[Initializer]');
 has 'allow_deferred_init' => (is => 'rw', isa => 'Bool', default => 0);
 has 'differentiable'      => (is => 'rw', isa => 'Bool', default => 1);
+# Internal state; _trainer is the trainer attached via _set_trainer.
 has [qw/_var _data _grad
-    _deferred_init
+    _deferred_init _trainer
     _ctx_list _ctx_map/]  => (is => 'rw', init_arg => undef);
 
 method grad_req(Maybe[GradReq] $req=)
@@ -138,6 +147,64 @@ method grad_req(Maybe[GradReq] $req=)
     }
 }
 
+method shape(@args)
+{
+    # No arguments: plain getter.
+    return $self->_shape if not @args;
+
+    my $requested = $args[0];
+
+    # An explicit undef clears the shape.
+    if(not defined $requested)
+    {
+        $self->_shape(undef);
+        return undef;
+    }
+
+    # A bare dimension is promoted to a one-element shape.
+    my $new_shape = ref $requested ? $requested : [$requested];
+
+    # Nothing known yet: take the requested shape as-is.
+    if(not defined $self->_shape)
+    {
+        $self->_shape($new_shape);
+        return $self->_shape;
+    }
+
+    # A known shape may only be refined. Each current dim must either be
+    # unknown (false/0) or equal to the corresponding requested dim.
+    my $current    = $self->_shape;
+    my $compatible = 0;
+    if(@{ $current } == @{ $new_shape })
+    {
+        $compatible = 1;
+        for my $axis (0 .. $#{ $current })
+        {
+            my ($have, $want) = ($current->[$axis], $new_shape->[$axis]);
+            next if not $have;
+            next if $have == $want;
+            $compatible = 0;
+        }
+    }
+    assert($compatible, 'Expected shape is incompatible with given shape');
+    $self->_shape($new_shape);
+    return $self->_shape;
+}
+
+method _set_trainer($trainer)
+{
+    # A sparse parameter may be bound to only one trainer object. Rebinding
+    # it to a different trainer is refused; the same trainer or a cleared
+    # value (false) is accepted.
+    my $current  = $self->_trainer;
+    my $conflict = $self->stype ne 'default'
+        && $current
+        && $trainer
+        && Scalar::Util::refaddr($current) ne Scalar::Util::refaddr($trainer);
+    confess(
+        "Failed to set the trainer for Parameter '${\ $self->name }' because it was already set. ".
+        "More than one trainers for a ${\ $self->stype } Parameter is not supported."
+    ) if $conflict;
+    $self->_trainer($trainer);
+}
+
+method _get_row_sparse($arr_list, $ctx, AI::MXNet::NDArray $row_id)
+{
+    # Row-sparse reads go through the trainer, so one must be attached.
+    confess(
+        "Cannot get row_sparse data for Parameter '${\ $self->name }' when no ".
+        "Trainer is created with it."
+    ) unless $self->_trainer;
+
+    # Resolve the destination array(s) for $ctx, then have the trainer
+    # fill in the rows listed in $row_id.
+    my $dest = $self->_check_and_get($arr_list, $ctx);
+    $self->_trainer->_row_sparse_pull($self, $dest, $row_id);
+    return $dest;
+}
+
 method _check_and_get($arr_list, $ctx)
 {
     if(defined $arr_list)
@@ -157,31 +224,31 @@ method _check_and_get($arr_list, $ctx)
                 $ctx = AI::MXNet::Context->current_ctx;
             }
         }
-        my $idx;
-        if(ref $self->_ctx_map->[$ctx->device_type_id])
-        {
-            $idx = $self->_ctx_map->[$ctx->device_type_id][$ctx->device_id];
-        }
-        if(defined $idx)
+        my $ctx_list = $self->_ctx_map->[$ctx->device_type_id&1];
+        if($ctx->device_id < @{ $ctx_list })
         {
-            return $arr_list->[$idx];
+            my $idx = $ctx_list->[$ctx->device_id];
+            if(defined $idx)
+            {
+                return $arr_list->[$idx];
+            }
         }
         confess(
-            "Parameter ${\ $self->name } was not initialized on context $ctx. ".
+            "Parameter '${\ $self->name }' was not initialized on context $ctx. ".
             "It was only initialized on @{ $self->_ctx_list }."
         );
     }
     if(@{ $self->_deferred_init })
     {
         confess("DeferredInitializationError: ".
-            "Parameter ${\ $self->name } has not been initialized yet because initialization was ".
+            "Parameter '${\ $self->name }' has not been initialized yet because initialization was ".
             "deferred. Actual initialization happens during the first forward pass. ".
             "Please pass one batch of data through the network before accessing Parameters. ".
             "You can also avoid deferred initialization by specifying in_units, ".
             "num_features, etc., for network layers.");
     }
     confess(
-        "Parameter ${\ $self->name } has not been initialized. Note that ".
+        "Parameter '${\ $self->name }' has not been initialized. Note that ".
         "you should initialize parameters and create Trainer ".
         "with Block.collect_params() instead of Block.params ".
         "because the later does not include Parameters of ".
@@ -189,34 +256,41 @@ method _check_and_get($arr_list, $ctx)
     );
 }
 
-# (Re)initializes by loading from data.
+
+# (Re)initializes by loading from data. 
 method _load_init($data, $ctx)
 {
     if($self->shape)
     {
         for(zip($self->shape, $data->shape)) {
-            my ($i, $j) = @$_;
+            my ($self_dim, $data_dim) = @$_;
             assert(
-                ($i == 0 or $i == $j),
+                ($self_dim == 0 or $self_dim == $data_dim),
                 sprintf(
-                    "Failed loading Parameter %s from saved params: ".
-                    "shape incompatible expacted (%s) vs saved (%s)",
+                    "Failed loading Parameter '%s' from saved params: ".
+                    "shape incompatible expected (%s) vs saved (%s)",
                     $self->name, "@{$self->shape}", "@{$data->shape}"
                 )
             );
         }
+        $self->shape([map { $_->[0] ? $_->[0] : $_->[1] } zip($self->shape, $data->shape)]);
     }
     if($self->dtype)
     {
         assert(
             ($self->dtype eq $data->dtype),
             sprintf(
-                "Failed loading Parameter %s from saved params: ".
-                "dtype incompatible expacted %s vs saved %s",
+                "Failed loading Parameter '%s' from saved params: ".
+                "dtype incompatible expected %s vs saved %s",
                 $self->name, $self->dtype, $data->dtype
             )
         );
     }
+    if($self->stype ne $data->stype)
+    {
+        $data = $data->tostype($self->stype);
+    }
+
     if(blessed ($ctx) and $ctx->isa('AI::MXNet::Context'))
     {
         $ctx = [$ctx];
@@ -226,23 +300,28 @@ method _load_init($data, $ctx)
         if(@{ $self->_deferred_init })
         {
             assert(
-                ($ctx eq $self->_deferred_init->[1]),
+                (not defined $ctx or join('', @{ $ctx }) eq join('', @{ $self->_deferred_init->[1] })),
                 sprintf(
-                    "Failed to load Parameter %s on %s because it was ".
-                    "previous initialized on %s.",
+                    "Failed to load Parameter '%s' on %s because it was ".
+                    "previously initialized on %s.",
                     $self->name, $ctx, $self->list_ctx
                 )
             );
+            $ctx = $self->_deferred_init->[1];
+        }
+        elsif(not defined $ctx)
+        {
+            $ctx = [AI::MXNet::Context->cpu];
         }
         $self->_init_impl($data, $ctx);
     }
     else
     {
         assert(
-            (join('', @{ $ctx }) eq join('', @{ $self->list_ctx })),
+            (not defined $ctx or join('', @{ $ctx }) eq join('', @{ $self->list_ctx })),
             sprintf(
-                "Failed to load Parameter %s on %s because it was ".
-                "previous initialized on %s.",
+                "Failed to load Parameter '%s' on %s because it was ".
+                "previously initialized on %s.",
                 $self->name, "@$ctx", "@{$self->list_ctx}"
             )
         );
@@ -255,28 +334,34 @@ method _load_init($data, $ctx)
+# Completes a pending deferred initialization; returns immediately if none is pending.
+# The pending state is [init, ctx, default_init, data]. Data stashed by set_data()
+# is used as-is. Otherwise a zero cpu array is created and filled by $init
+# (or $default_init when $init is undef). All of this runs with autograd paused.
 method _finish_deferred_init()
 {
     return unless @{ $self->_deferred_init };
-    my ($init, $ctx, $default_init) = @{ $self->_deferred_init };
+    my ($init, $ctx, $default_init, $data) = @{ $self->_deferred_init };
     $self->_deferred_init([]);
+    # The pending state is cleared before validation, so it is gone even if the assert fails.
+    # NOTE(review): %s below receives the shape array ref, so the message shows ARRAY(0x...).
     assert(
         (defined($self->shape) and product(@{ $self->shape }) > 0),
         sprintf(
-            "Cannot initialize Parameter %s because it has ".
+            "Cannot initialize Parameter '%s' because it has ".
             "invalid shape: %s. Please specify in_units, ".
             "in_channels, etc for `Block`s.",
             $self->name, $self->shape
         )
     );
     AI::MXNet::AutoGrad->pause(sub {
-        my $data = AI::MXNet::NDArray->zeros(
-            $self->shape, dtype => $self->dtype, ctx => AI::MXNet::Context->cpu
-        );
-        AI::MXNet::Initializer->new->(
-            AI::MXNet::InitDesc->new(
-                name => $self->name,
-                attrs => { __init__ => defined $init ? "$init" : "$default_init" }
-            ),
-            $data
-        );
+        # Only build and initialize a fresh array when set_data() stashed nothing.
+        if(not defined $data)
+        {
+            $data = AI::MXNet::NDArray->zeros(
+                $self->shape,
+                dtype => $self->dtype,
+                ctx => AI::MXNet::Context->cpu,
+                stype => $self->stype
+            );
+            AI::MXNet::Initializer->new->(
+                AI::MXNet::InitDesc->new(
+                    name => $self->name,
+                    attrs => { __init__ => defined $init ? "$init" : "$default_init" }
+                ),
+                $data
+            );
+        }
         $self->_init_impl($data, $ctx);
     });
 }
@@ -285,14 +370,10 @@ method _finish_deferred_init()
 method _init_impl($data, $ctx_list)
 {
     $self->_ctx_list([@{ $ctx_list }]);
-    $self->_ctx_map([]);
+    $self->_ctx_map([[], []]);
     enumerate(sub {
         my ($i, $ctx) = @_;
-        while(@{ $self->_ctx_map } <= $ctx->device_type_id)
-        {
-            push @{ $self->_ctx_map }, [];
-        }
-        my $dev_list = $self->_ctx_map->[$ctx->device_type_id];
+        my $dev_list = $self->_ctx_map->[$ctx->device_type_id&1];
         while(@{ $dev_list } <= $ctx->device_id)
         {
             push @{ $dev_list }, undef;
@@ -311,20 +392,40 @@ method _init_grad()
         $self->_grad(undef);
         return;
     }
-    $self->_grad([map { AI::MXNet::NDArray->zeros_like($_) } @{ $self->_data }]);
-    AI::MXNet::AutoGrad->mark_variables($self->list_data, $self->list_grad, grad_reqs => $self->grad_req);
+    $self->_grad([
+        map {
+            AI::MXNet::NDArray->zeros(
+                $_->shape, dtype => $_->dtype,
+                ctx => $_->context, stype => $self->grad_stype
+            )
+        } @{ $self->_data }
+    ]);
+    AI::MXNet::AutoGrad->mark_variables(
+        $self->_check_and_get($self->_data, []),
+        $self->_grad,
+        grad_reqs => $self->grad_req
+    );
 }
 
-# Reduce data from multiple context.
-
+# Reduce data from multiple contexts to cpu.
+# Dense ('default') storage: the per-context copies are averaged on cpu.
+# Other storage types: all rows are pulled through the attached trainer into a
+# row_sparse cpu array, so a trainer must have been set.
 method _reduce()
 {
-    my $block = $self->list_data;
-    my $data = AI::MXNet::NDArray->add_n(map { $_->copyto(AI::MXNet::Context->cpu) } @{ $block }) / @{ $block };
+    my $data;
+    my $ctx = AI::MXNet::Context->cpu;
+    if($self->stype eq 'default')
+    {
+        my $block = $self->list_data;
+        $data = AI::MXNet::NDArray->add_n(map { $_->copyto($ctx) } @{ $block }) / @{ $block };
+    }
+    else
+    {
+        # Row ids 0 .. shape[0]-1, i.e. request every row.
+        my $all_row_ids = AI::MXNet::NDArray->arange(stop => $self->shape->[0], dtype=>'int64', ctx=>$ctx);
+        $data = AI::MXNet::NDArray->zeros($self->shape, stype=>'row_sparse', ctx=>$ctx);
+        $self->_trainer->_row_sparse_pull($self, $data, $all_row_ids);
+    }
     return $data;
 }
 
-
 =head2 initialize
 
         Initializes parameter and gradient arrays. Only used for `NDArray` API.
@@ -377,7 +478,7 @@ method initialize(
     if(defined $self->_data and not $force_reinit)
     {
         AI::MXNet::Logging->warning(
-            "Parameter %s is already initialized, ignoring. ".
+            "Parameter '%s' is already initialized, ignoring. ".
             "Set force_reinit=True to re-initialize.", $self->name
         );
         return;
@@ -403,13 +504,13 @@ method initialize(
     {
         if($self->allow_deferred_init)
         {
-            $self->_deferred_init([$init, $ctx, $default_init]);
+            $self->_deferred_init([$init, $ctx, $default_init, undef]);
             return;
         }
-        confess("Cannot initialize Parameter ${\ $self->name } because it has ".
+        confess("Cannot initialize Parameter '${\ $self->name }' because it has ".
                 "invalid shape: @{$self->shape//[]}.");
     }
-    $self->_deferred_init([$init, $ctx, $default_init]);
+    $self->_deferred_init([$init, $ctx, $default_init, undef]);
     $self->_finish_deferred_init;
 }
 
@@ -437,12 +538,12 @@ method reset_ctx(Maybe[AI::MXNet::Context|ArrayRef[AI::MXNet::Context]] :$ctx=AI
     }
     elsif(@{ $self->_deferred_init })
     {
-        my ($init, undef, $default_init) = @{ $self->_deferred_init };
-        $self->_deferred_init([$init, $ctx, $default_init]);
+        my ($init, undef, $default_init, $data) = @{ $self->_deferred_init };
+        $self->_deferred_init([$init, $ctx, $default_init, $data]);
     }
     else
     {
-        confess("Cannot reset context for Parameter ${ \ $self->name } because it ".
+        confess("Cannot reset context for Parameter '${ \ $self->name }' because it ".
                 "has not been initialized.");
     }
 }
@@ -454,20 +555,93 @@ method reset_ctx(Maybe[AI::MXNet::Context|ArrayRef[AI::MXNet::Context]] :$ctx=AI
 
+# Assigns $data to this parameter on every context it lives on. The shape is
+# validated/refined through shape() first. While initialization is still
+# deferred, the value is stashed and later used by _finish_deferred_init().
 method set_data($data)
 {
-    assert(
-        (defined $self->_data),
-        "Parameter ${\ $self->name } has not been initialized"
-    );
-    for my $arr (@{ $self->list_data })
+    $self->shape($data->shape);
+    if(not defined $self->_data)
+    {
+        assert(
+            (@{ $self->_deferred_init }),
+            "Parameter '${\ $self->name }' has not been initialized"
+        );
+        $self->_deferred_init->[3] = $data;
+        return;
+    }
+
+    # if update_on_kvstore, we need to make sure the copy stored in kvstore is in sync
+    if($self->_trainer and $self->_trainer->_kv_initialized and $self->_trainer->update_on_kvstore)
+    {
+        if(!grep { Scalar::Util::refaddr($self) == Scalar::Util::refaddr($_) } @{ $self->_trainer->_params_to_init })
+        {
+            $self->_trainer->_reset_kvstore();
+        }
+    }
+    for my $arr (@{ $self->_check_and_get($self->_data, []) })
     {
         $arr .= $data;
     }
 }
 
+=head2 row_sparse_data
+
+        Returns a copy of the 'row_sparse' parameter on the same context as row_id's.
+        The copy only retains rows whose ids occur in provided row ids.
+        The parameter must have been initialized on this context before.
+
+        Parameters
+        ----------
+        $row_id: AI::MXNet::NDArray
+            Row ids to retain for the 'row_sparse' parameter.
+
+        Returns
+        -------
+        AI::MXNet::NDArray on row_id's context
+=cut
+
+method row_sparse_data(AI::MXNet::NDArray $row_id)
+{
+    # Row-wise reads only make sense for row_sparse storage.
+    if($self->stype ne 'row_sparse')
+    {
+        confess(
+            "Cannot return a copy of Parameter '${\ $self->name }' via row_sparse_data() ".
+            "because its storage type is ${\ $self->stype }. Please use data() instead."
+        );
+    }
+    # Pull only the requested rows, onto the context that holds $row_id.
+    return $self->_get_row_sparse($self->_data, $row_id->context, $row_id);
+}
+
+=head2 list_row_sparse_data
+
+        Returns copies of the 'row_sparse' parameter, one per context, in the
+        order the contexts were given at initialization. Each copy keeps only
+        the rows whose ids appear in $row_id. The parameter must already be
+        initialized.
+
+        Parameters
+        ----------
+        $row_id: AI::MXNet::NDArray
+            Row ids to retain for the 'row_sparse' parameter.
+
+        Returns
+        -------
+        array ref of AI::MXNet::NDArrays
+=cut
+
+method list_row_sparse_data(AI::MXNet::NDArray $row_id)
+{
+    # Happy path first: row_sparse storage is served by the trainer for all contexts.
+    return $self->_get_row_sparse($self->_data, [], $row_id)
+        if $self->stype eq 'row_sparse';
+    confess(
+        "Cannot return copies of Parameter '${\ $self->name }' on all contexts via ".
+        "list_row_sparse_data() because its storage type is ${\ $self->stype }. Please ".
+        "use data() instead."
+    );
+}
+
 =head2 data
 
         Returns a copy of this parameter on one context. Must have been
-        initialized on this context before.
+        initialized on this context before. For sparse parameters, use
+        row_sparse_data instead.
 
         Parameters
         ----------
@@ -481,17 +655,35 @@ method set_data($data)
 
 method data(Maybe[AI::MXNet::Context] $ctx=)
 {
+    # Only dense ('default') storage can be returned whole; sparse parameters
+    # have to be read with row_sparse_data().
     if($self->stype ne 'default')
     {
+        # Resolve the context only so the error message can name it.
         $ctx //= AI::MXNet::Context->current_ctx;
         confess(
             "Cannot return a copy of Parameter '${\ $self->name }' on ctx $ctx via data() ".
             "because its storage type is ${\ $self->stype }. Please use row_sparse_data() ".
             "instead."
         );
     }
     return $self->_check_and_get($self->_data, $ctx);
 }
 
 =head2 list_data
 
         Returns copies of this parameter on all contexts, in the same order
-        as creation.
+        as creation. For sparse parameters, use list_row_sparse_data
+        instead.
 =cut
 
 method list_data()
 {
+    # Sparse storage cannot be returned densely; direct the caller to the
+    # row-sparse counterpart. Interpolate the parameter's name here: calling
+    # data() would itself confess for non-default storage and hide this message.
+    if($self->stype ne 'default')
+    {
+        confess(
+            "Cannot return copies of Parameter '${\ $self->name }' on all contexts via list_data() ".
+            "because its storage type is ${\ $self->stype }. Please use list_row_sparse_data() ".
+            "instead."
+        );
+    }
     return $self->_check_and_get($self->_data, [])
 }
 
@@ -562,7 +754,7 @@ method list_ctx()
+# Writes zeros into every existing gradient array (each array is passed as its
+# own `out`); does nothing when no gradients are allocated.
 method zero_grad()
 {
     return unless defined $self->_grad;
-    map { $_ .= 0 } @{ $self->_grad };
+    AI::MXNet::NDArray->zeros_like($_, { out => $_ }) for @{ $self->_grad };
 }
 
 =head2 var
@@ -578,13 +770,129 @@ method var()
             AI::MXNet::Symbol->var(
                 $self->name, shape => $self->shape, dtype => $self->dtype,
                 lr_mult => $self->lr_mult, wd_mult => $self->wd_mult,
-                init => $self->init
+                init => $self->init, stype => $self->stype
             )
         );
     }
     return $self->_var;
 }
 
+=head2 cast
+
+    Cast data and gradient of this Parameter to a new data type.
+
+    Parameters
+    ----------
+    $dtype : Dtype
+        The new data type.
+=cut
+
+method cast(Dtype $dtype)
+{
+    $self->dtype($dtype);
+    # Not allocated yet: only the recorded dtype changes.
+    return if not defined $self->_data;
+    AI::MXNet::AutoGrad->pause(sub {
+        my @converted = map { $_->astype($dtype) } @{ $self->_data };
+        $self->_data(\@converted);
+        my $grads = $self->_grad;
+        return if not defined $grads;
+        $self->_grad([map { $_->astype($dtype) } @{ $grads }]);
+        # Re-mark the converted data arrays as autograd variables with the converted gradients.
+        AI::MXNet::AutoGrad->mark_variables($self->_data, $self->_grad, grad_reqs => $self->grad_req);
+    });
+}
+
+package AI::MXNet::Gluon::Constant;
+use strict;
+use warnings;
+use Mouse;
+extends 'AI::MXNet::Gluon::Parameter';
+
+=head1 NAME
+
+    AI::MXNet::Gluon::Constant - A constant parameter for holding immutable tensors.
+=cut
+
+=head1 DESCRIPTION
+
+    A constant parameter for holding immutable tensors.
+    Constants are ignored by autograd and Trainer, thus their values
+    will not change during training. But you can still update their values
+    manually with the set_data method.
+
+    Constants can be created with either
+
+        $const = mx->gluon->Constant('const', [[1,2],[3,4]]);
+
+    or
+
+        package Block;
+        use AI::MXNet::Gluon::Mouse;
+        extends 'AI::MXNet::Gluon::Block';
+        sub BUILD
+        {
+            my $self = shift;
+            $self->const($self->params->get_constant('const', [[1,2],[3,4]]));
+        }
+
+    Constructor Attributes
+    ----------------------
+    name : str
+        Name of the parameter.
+    value : AcceptableInput (perl array, pdl, ndarray, etc)
+        Initial value for the constant.
+=cut
+
+use AI::MXNet::Base;
+use Scalar::Util qw(refaddr);
+around BUILDARGS => \&AI::MXNet::Base::process_arguments;
+method python_constructor_arguments() { ['name', 'value'] }
+has 'value'     => (is => 'rw', isa => 'AcceptableInput');
+has '+_grad_req' => (is => 'rw', default => 'null');
+use overload '""' => sub {
+        my $self = shift;
+        "Constant " . $self->name.
+        " (shape=(" . join(', ', @{ $self->shape//[] }) .")".
+        ", dtype=" . $self->dtype.
+        ", stype=" . $self->stype.")"
+    },
+    fallback => 1;
+
+
+sub BUILD
+{
+    my $self = shift;
+    # Normalize the initial value to an NDArray of the parameter's dtype.
+    if(not (blessed $self->value and $self->value->isa('AI::MXNet::NDArray')))
+    {
+        $self->value(AI::MXNet::NDArray->array($self->value, dtype => $self->dtype));
+    }
+    $self->shape($self->value->shape);
+    # Each constant gets its own initializer class, whose _init_weight copies
+    # the stored value into the target array. The parameter name may contain
+    # characters that are illegal in a package name, so they are replaced with '_'.
+    (my $safe_name = $self->name) =~ s/[^A-Za-z0-9_]/_/g;
+    my $init = "AI::MXNet::Gluon::Constant::Init_${safe_name}_${\ refaddr($self) }";
+    my $tmp =<<"EOP";
+    package $init;
+    use Mouse;
+    extends 'AI::MXNet::Initializer';
+    sub _init_weight
+    {
+        \$self->value->copyto(\$_[2]);
+    }
+    $init->register;
+    1;
+EOP
+    # The evaluated code closes over the lexical $self above. A compile error
+    # must not be swallowed; otherwise $init->new fails with an unrelated message.
+    eval $tmp
+        or Carp::confess("Failed to create the initializer for Constant '${\ $self->name }': $@");
+    $self->init($init->new);
+}
+
+# Constants never take gradients: any request other than 'null' is logged and ignored.
+method grad_req($req=)
+{
+    if(defined $req and $req ne 'null')
+    {
+        AI::MXNet::Logging->warning(
+            'Constant parameter "%s" does not support '.
+            'grad_req other than "null", and new value "%s" '.
+            'is ignored.',
+            $self->name, $req
+        );
+    }
+    return 'null';
+}
 
 package AI::MXNet::Gluon::ParameterDict;
 use AI::MXNet::Base;
@@ -650,6 +958,10 @@ method prefix()
     $self->_prefix;
 }
 
+# Accessor for the underlying name => Parameter storage of this dictionary.
+method params()
+{
+    return $self->_params;
+}
 
 method _get_impl($name)
 {
@@ -702,12 +1014,51 @@ method get(Str $name, %kwargs)
         {
             if($param->can($k))
             {
-                assert(
-                    (not defined $v or Dumper($v) eq Dumper($param->$k)),
-                    "Cannot retrieve Parameter $name because desired attribute ".
-                    "does not match with stored for attribute $k: ".
-                    "desired ".Dumper($v)." vs stored ". Dumper($param->$k)
-                );
+                if(defined $param->$k)
+                {
+                    my $existing = $param->$k;
+                    if($k eq 'shape' and @{$v} == @{$existing})
+                    {
+                        my @inferred_shape;
+                        my $matched = 1;
+                        for(zip($v, $existing))
+                        {
+                            my ($dim1, $dim2) = @$_;
+                            if($dim1 != $dim2 and $dim1 * $dim2 != 0)
+                            {
+                                $matched = 0;
+                                 last;
+                            }
+                            elsif($dim1 == $dim2)
+                            {
+                                push @inferred_shape, $dim1;
+                            }
+                            elsif($dim1 == 0)
+                            {
+                                push @inferred_shape, $dim2;
+                            }
+                            else
+                            {
+                                push @inferred_shape, $dim1;
+                            }
+                        }
+                        if($matched)
+                        {
+                            $param->_shape(\@inferred_shape);
+                            next;
+                        }
+                    }
+                    assert(
+                        (not defined $v or Dumper($v) eq Dumper($param->$k)),
+                        "Cannot retrieve Parameter $name because desired attribute ".
+                        "does not match with stored for attribute $k: ".
+                        "desired ".Dumper($v)." vs stored ". Dumper($param->$k)
+                    );
+                }
+                else
+                {
+                    $param->$k($v);
+                }
             }
             else
             {
@@ -723,10 +1074,10 @@ method get(Str $name, %kwargs)
     Copies all Parameters in $other to self.
 =cut
 
-method update($other)
+method update($other, Maybe[Str] $select=)
 {
     my @keys = $other->keys;
-    for my $k (@keys)
+    for my $k (grep { not defined $select or /$select/ } @keys)
     {
         if($self->_params->exists($k))
         {
@@ -743,6 +1094,50 @@ method update($other)
     }
 }
 
+=head2 get_constant
+
+        Retrieves the AI::MXNet::Gluon::Constant named $self->prefix.$name.
+        The lookup goes through _get_impl (the same lookup 'get' uses). If
+        nothing is found, a new Constant is created from $value and inserted
+        into this dictionary.
+
+        Parameters
+        ----------
+        name : str
+            Name of the desired Constant. It will be prepended with this dictionary's
+            prefix.
+        value : array-like
+            Initial value of the constant. Required when the constant does not exist
+            yet; passing it for an existing constant is an error.
+
+        Returns
+        -------
+        Constant
+            The created or retrieved Constant.
+=cut
+
+method get_constant(Str $name, Maybe[AcceptableInput] $value=)
+{
+    my $full_name = $self->prefix . $name;
+    my $param     = $self->_get_impl($full_name);
+    if(defined $param)
+    {
+        # Existing constants are immutable through this interface.
+        confess("reinit of Constant $full_name is not allowed") if defined $value;
+        return $param;
+    }
+    confess(
+        "No constant named '$full_name'. Please specify value ".
+        "if you want to create a new constant."
+    ) unless defined $value;
+    $param = AI::MXNet::Gluon::Constant->new($full_name, $value);
+    $self->_params->set($full_name, $param);
+    return $param;
+}
+
 =head2 initialize
 
         Initializes all Parameters managed by this dictionary to be used for 'NDArray'
@@ -757,9 +1152,10 @@ method update($other)
             Keeps a copy of Parameters on one or many context(s).
         :$force_reinit : bool, default False
             Whether to force re-initialization if parameter is already initialized.
+        :$verbose : bool, default False
+            Whether to verbosely print out details on initialization.
 =cut
 
-
 method initialize(
     Initializer                                            :$init=AI::MXNet::Initializer->Uniform(),
     Maybe[AI::MXNet::Context|ArrayRef[AI::MXNet::Context]] :$ctx=,
@@ -873,6 +1269,7 @@ method save(Str $filename, Str $strip_prefix='')
         :$restore_prefix : str, default ''
             prepend prefix to names of stored parameters before loading.
 =cut
+
 method load(
     Str                                              $filename,
     AI::MXNet::Context|ArrayRef[AI::MXNet::Context] :$ctx=AI::MXNet::Context->current_ctx,
@@ -894,7 +1291,7 @@ method load(
     }
     my $lprefix  = length $restore_prefix;
     my %orig_load = %{ AI::MXNet::NDArray->load($filename) };
-    my %arg_dict  = map { ($restore_prefix.$_, $orig_load{$_}) } keys %orig_load;
+    my %arg_dict  = map { my $k = $_; s/^(?:arg|aux)://; ($restore_prefix.$_, $orig_load{$k}) } keys %orig_load;
     if(not $allow_missing)
     {
         for my $name ($self->keys())
@@ -919,7 +1316,7 @@ method load(
             );
             next;
         }
-        @{$self}{$name}->_load_init($arg_dict{$name}, $ctx);
+        $self->{ $name }->_load_init($arg_dict{$name}, $ctx);
     }
 }
 
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/RNN.pm b/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/RNN.pm
index 6a51227..cdd9468 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/RNN.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/RNN.pm
@@ -31,7 +31,7 @@ sub import
         {
             my $short_name_package =<<"EOP";
             package $short_name;
-            \@${short_name}::ISA = ('AI::MXNet::Gluon::RNN_');;
+            \@${short_name}::ISA = ('AI::MXNet::Gluon::RNN_');
             1;
 EOP
             eval $short_name_package;
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/RNN/Cell.pm b/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/RNN/Cell.pm
index a3fb3c5..c14b792 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/RNN/Cell.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/RNN/Cell.pm
@@ -24,12 +24,12 @@ use AI::MXNet::Function::Parameters;
 
+# Flattens the state_info($batch_size) lists of every cell in $cells into one array ref.
+# $cells must provide ->values (presumably a block's _children container - verify callers).
 method _cells_state_info($cells, $batch_size)
 {
-    return [map { @{ $_->state_info($batch_size) } } @{ $cells }];
+    return [map { @{ $_->state_info($batch_size) } } $cells->values];
 }
 
+# Flattens the begin_state(%kwargs) lists of every cell in $cells into one array ref.
+# $cells must provide ->values (presumably a block's _children container - verify callers).
 method _cells_begin_state($cells, %kwargs)
 {
-    return [map { @{ $_->begin_state(%kwargs) } } @{ $cells }];
+    return [map { @{ $_->begin_state(%kwargs) } } $cells->values];
 }
 
 method _get_begin_state(GluonClass $F, $begin_state, GluonInput $inputs, $batch_size)
@@ -158,7 +158,7 @@ method reset()
 {
     $self->init_counter(-1);
     $self->counter(-1);
-    $_->reset for @{ $self->_children };
+    $_->reset for $self->_children->values;
 }
 
 =head2 state_info
@@ -290,7 +290,6 @@ method unroll(
 
     my $states = $begin_state;
     my $outputs = [];
-    use Data::Dumper;
     for my $i (0..$length-1)
     {
         my $output;
@@ -805,7 +804,7 @@ method unroll(Int $length, GluonInput $inputs, Maybe[GluonInput] :$begin_state=,
     $self->reset();
     my ($F, $batch_size);
     ($inputs, undef, $F, $batch_size) = $self->_format_sequence($length, $inputs, $layout, undef);
-    my $num_cells = @{ $self->_children };
+    my $num_cells = $self->_children->keys;
     $begin_state = $self->_get_begin_state($F, $begin_state, $inputs, $batch_size);
     my $p = 0;
     my @next_states;
@@ -820,7 +819,7 @@ method unroll(Int $length, GluonInput $inputs, Maybe[GluonInput] :$begin_state=,
             merge_outputs => ($i < ($num_cells - 1)) ? undef : $merge_outputs
         );
         push @next_states, @{ $states };
-    }, $self->_children);
+    }, [$self->_children->values]);
     return ($inputs, \@next_states);
 }
 
@@ -829,7 +828,7 @@ method call($inputs, $states)
     $self->counter($self->counter + 1);
     my @next_states;
     my $p = 0;
-    for my $cell (@{ $self->_children })
+    for my $cell ($self->_children->values)
     {
         assert(not $cell->isa('AI::MXNet::Gluon::RNN::BidirectionalCell'));
         my $n = @{ $cell->state_info() };
@@ -841,7 +840,7 @@ method call($inputs, $states)
     return ($inputs, \@next_states);
 }
 
-use overload '@{}' => sub { shift->_children };
+# Array-dereferencing the cell yields a new array ref of its children (_children->values).
+use overload '@{}' => sub { [shift->_children->values] };
 use overload '""'  => sub {
     my $self = shift;
     my $s = "%s(\n%s\n)";
@@ -849,7 +848,7 @@ use overload '""'  => sub {
     enumerate(sub {
         my ($i, $m) = @_;
         push @children, "($i): ". AI::MXNet::Base::_indent("$m", 2);
-    }, $self->_children);
+    }, [$self->_children->values]);
     return sprintf($s, $self->_class_name, join("\n", @children));
 };
 
@@ -1178,7 +1177,7 @@ method unroll(Int $length, GluonInput $inputs, Maybe[GluonInput] :$begin_state=,
     $begin_state //= $self->_get_begin_state($F, $begin_state, $inputs, $batch_size);
 
     my $states = $begin_state;
-    my ($l_cell, $r_cell) = @{ $self->_children };
+    my ($l_cell, $r_cell) = $self->_children->values;
     $l_cell->state_info($batch_size);
     my ($l_outputs, $l_states) = $l_cell->unroll(
             $length, $inputs,
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/Trainer.pm b/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/Trainer.pm
index c2e8f31..1b3b49f 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/Trainer.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/Trainer.pm
@@ -23,6 +23,7 @@ use AI::MXNet::Function::Parameters;
 use IO::File;
 use Mouse;
 
+
 =head1 NAME
 
     AI::MXNet::Gluon::Trainer
@@ -35,53 +36,78 @@ use Mouse;
 
     Parameters
     ----------
-    params : ParameterDict
+    params : AI::MXNet::Gluon::ParameterDict
         The set of parameters to optimize.
     optimizer : str or Optimizer
         The optimizer to use. See
         `help <http://mxnet.io/api/python/optimization/optimization.html#the-mxnet-optimizer-package>`_
         on Optimizer for a list of available optimizers.
-    optimizer_params : dict
+    optimizer_params : hash ref
         Key-word arguments to be passed to optimizer constructor. For example,
-        `{'learning_rate': 0.1}`. All optimizers accept learning_rate, wd (weight decay),
+        {learning_rate => 0.1}. All optimizers accept learning_rate, wd (weight decay),
         clip_gradient, and lr_scheduler. See each optimizer's
         constructor for a list of additional supported arguments.
     kvstore : str or KVStore
         kvstore type for multi-gpu and distributed training. See help on
-        :any:`mxnet.kvstore.create` for more information.
+        mx->kvstore->create for more information.
+    compression_params : hash ref
+        Specifies type of gradient compression and additional arguments depending
+        on the type of compression being used. For example, 2bit compression requires a threshold.
+        Arguments would then be {type => '2bit', threshold => 0.5}
+        See AI::MXNet::KVStore->set_gradient_compression method for more details on gradient compression.
+    update_on_kvstore : Bool, default undef
+        Whether to perform parameter updates on kvstore. If undef, then trainer will choose the more
+        suitable option depending on the type of kvstore.
+
+    Properties
+    ----------
+    learning_rate : float
+        The current learning rate of the optimizer. Given an Optimizer object
+        optimizer, its learning rate can be accessed as optimizer->learning_rate.
 =cut
 
-has '_params'          => (is => 'rw', init_arg => 'params', isa => 'HashRef|ArrayRef|AI::MXNet::Gluon::ParameterDict');
-has 'optimizer'        => (is => 'ro', isa => 'Optimizer');
-has 'optimizer_params' => (is => 'ro', isa => 'Maybe[HashRef]');
-has '_kv_store'        => (is => 'rw', init_arg => 'kvstore', isa => 'Maybe[KVStore]', default => 'device');
+has 'params'             => (is => 'rw', isa => 'HashRef|ArrayRef|AI::MXNet::Gluon::ParameterDict');
+has 'optimizer'          => (is => 'ro', isa => 'Optimizer');
+has 'optimizer_params'   => (is => 'ro', isa => 'Maybe[HashRef]');
+has 'compression_params' => (is => 'ro', isa => 'Maybe[HashRef]');
+has 'kvstore'            => (is => 'rw', isa => 'Maybe[KVStore]', default => 'device');
+has 'update_on_kvstore'  => (is => 'rw', isa => 'Maybe[Bool]');
 has [qw/_scale _contexts
     _kv_initialized
-    _update_on_kvstore
+    _param2idx
+    _kvstore_params
+    _contains_sparse
+    _params_to_init
     _updaters
     _optimizer/]       => (is => 'rw', init_arg => undef);
 around BUILDARGS => \&AI::MXNet::Base::process_arguments;
-method python_constructor_arguments() { ['params', 'optimizer', 'optimizer_params'] }
+method python_constructor_arguments()
+{
+    [qw/params optimizer optimizer_params kvstore compression_params update_on_kvstore/]
+}
 
 sub BUILD
 {
     my $self = shift;
     my @params;
-    if(blessed $self->_params)
+    if(blessed $self->params)
     {
-        @params = $self->_params->values;
+        @params = $self->params->values;
     }
-    elsif(ref $self->_params eq 'HASH')
+    elsif(ref $self->params eq 'HASH')
     {
-        @params = values %{ $self->_params };
+        @params = values %{ $self->params };
     }
     else
     {
-        @params = @{ $self->_params };
+        @params = @{ $self->params };
     }
-    $self->_params([]);
-    for my $param (@params)
+    $self->params([]);
+    $self->_contains_sparse(0);
+    $self->_param2idx({});
+    for(enumerate(\@params))
     {
+        my ($i, $param) = @$_;
         if(not(blessed $param and $param->isa('AI::MXNet::Gluon::Parameter')))
         {
             confess(
@@ -89,19 +115,33 @@ sub BUILD
                 "got list of [$param]."
             );
         }
-        push @{ $self->_params }, $param;
+        $self->_param2idx->{ $param->name } = $i;
+        push @{ $self->params }, $param;
+        $param->_set_trainer($self);
+        if($param->stype ne 'default')
+        {
+            $self->_contains_sparse(1);
+        }
     }
     my $optimizer_params = $self->optimizer_params//{};
     $self->_scale(delete $optimizer_params->{rescale_grad}//1);
     $self->_contexts($self->_check_contexts);
     $self->_init_optimizer($self->optimizer, $optimizer_params);
+    $self->_kvstore_params({
+        kvstore => $self->kvstore,
+        update_on_kvstore => $self->update_on_kvstore
+    });
     $self->_kv_initialized(0);
+    $self->kvstore(undef);
+    $self->update_on_kvstore(undef);
+    $self->_params_to_init([]);
+    $self->_reset_kvstore();
 }
 
 method _check_contexts()
 {
     my $contexts;
-    for my $param (@{ $self->_params })
+    for my $param (@{ $self->params })
     {
         my $ctx = $param->list_ctx;
         assert(
@@ -117,7 +157,7 @@ method _check_contexts()
 
 method _init_optimizer($optimizer, $optimizer_params)
 {
-    my %param_dict = map { $_ => $self->_params->[$_] } 0 .. @{ $self->_params } - 1;
+    my %param_dict = map { $_ => $self->params->[$_] } 0 .. @{ $self->params } - 1;
     if(blessed $optimizer and $optimizer->isa('AI::MXNet::Optimizer'))
     {
         assert(
@@ -142,104 +182,187 @@ method _init_optimizer($optimizer, $optimizer_params)
     ]);
 }
 
-method _init_kvstore()
+method _init_params()
 {
-    my %arg_arrays = map { $_->name => $_->data($self->_contexts->[0]) } @{ $self->_params };
-    my ($kvstore, $update_on_kvstore) = AI::MXNet::Module::_create_kvstore(
-        $self->_kv_store, scalar(@{$self->_contexts }), \%arg_arrays
+    assert(
+        $self->_kv_initialized,
+        "Cannot initialize parameters in KVStore ".
+        "when KVStore is not initialized."
     );
+    my @params_to_init;
+    if($self->kvstore)
+    {
+        for my $param (@{ $self->_params_to_init })
+        {
+            if(@{ $param->_deferred_init })
+            {
+                push @params_to_init, $param;
+            }
+            else
+            {
+                my $param_arrays = $param->_check_and_get($param->_data, []);
+                my $idx = $self->_param2idx->{ $param->name };
+                $self->kvstore->init($idx, $param_arrays->[0]);
+                if($param->stype eq 'default')
+                {
+                    $self->kvstore->pull($idx, out => $param_arrays, priority=>-$idx);
+                }
+            }
+        }
+    }
+    $self->_params_to_init(\@params_to_init);
+}
+
+method _reset_kvstore()
+{
+    if($self->kvstore and $self->kvstore->type =~ /dist/)
+    {
+        confess("Cannot reset distributed KVStore.");
+    }
+    $self->_kv_initialized(0);
+    $self->kvstore(undef);
+    $self->update_on_kvstore(undef);
+    $self->_params_to_init([@{ $self->params }]);
+}
+
+method _init_kvstore()
+{
+    my $config = $self->_kvstore_params;
+    my ($kvstore, $update_on_kvstore);
+    if($self->_contains_sparse)
+    {
+        ($kvstore, $update_on_kvstore) = AI::MXNet::Module::_create_sparse_kvstore($config->{kvstore});
+        # update_on_kvstore is set to False by the user
+        if(defined $config->{update_on_kvstore} and not $config->{update_on_kvstore})
+        {
+            confess(
+                "Cannot set update_on_kvstore to False when sparse ".
+                "gradients and/or sparse weights are present."
+            )
+        }
+    }
+    else
+    {
+        my %arg_arrays = map { $_->name => $_->data($self->_contexts->[0]) } @{ $self->params };
+        ($kvstore, $update_on_kvstore) = AI::MXNet::Module::_create_kvstore(
+            $config->{kvstore}, scalar(@{$self->_contexts }), \%arg_arrays
+        );
+        if(defined $config->{update_on_kvstore})
+        {
+            $update_on_kvstore = $config->{update_on_kvstore};
+        }
+    }
     if($kvstore)
     {
-        if($kvstore->type =~ /dist/)
+        if($self->compression_params)
+        {
+            $kvstore->set_gradient_compression($self->compression_params);
+        }
+        # kv->pull(row_sparse_grad) is not supported
+        if($kvstore->type =~ /dist/ and not $self->_contains_sparse)
         {
             $update_on_kvstore = 0;
         }
-        enumerate(sub {
-            my ($i, $param) = @_;
-            my $param_arrays = $param->list_data;
-            $kvstore->init($i, $param_arrays->[0]);
-            $kvstore->pull($i, out => $param_arrays, priority => -$i);
-        }, $self->_params);
         if($update_on_kvstore)
         {
+            # optimizer preferably needs to be set before init for multiprecision
             $kvstore->set_optimizer($self->_optimizer);
         }
-        $self->_kv_store($kvstore);
-        $self->_update_on_kvstore($update_on_kvstore);
+        $self->kvstore($kvstore);
+        $self->update_on_kvstore($update_on_kvstore);
     }
     else
     {
-        $self->_kv_store(undef);
-        $self->_update_on_kvstore(undef)
+        $self->kvstore(undef);
+        $self->update_on_kvstore(undef);
     }
     $self->_kv_initialized(1);
 }
 
+method _row_sparse_pull($parameter, $out, $row_id)
+{
+    # initialize kv and params if not already
+    $self->_init_kvstore() unless $self->_kv_initialized;
+    $self->_init_params() if scalar(@{ $self->_params_to_init });
+    $self->kvstore->row_sparse_pull(
+        $self->_param2idx->{ $parameter->name },
+        out => $out,
+        row_ids => $row_id
+    );
+}
+
 =head2 step
 
         Makes one step of parameter update. Should be called after
-        `autograd.compute_gradient` and outside of `record()` scope.
+        `autograd->backward()` and outside of `record()` scope.
+
+        For normal parameter updates, `step()` should be used, which internally calls
+        `allreduce_grads()` and then `update()`. However, if you need to get the reduced
+        gradients to perform certain transformation, such as in gradient clipping, then
+        you may want to manually call `allreduce_grads()` and `update()` separately.
 
         Parameters
         ----------
-        batch_size : int
+        $batch_size : Int
             Batch size of data processed. Gradient will be normalized by `1/batch_size`.
             Set this to 1 if you normalized loss manually with `loss = mean(loss)`.
-        ignore_stale_grad : bool, optional, default=False
+        $ignore_stale_grad : Bool, optional, default=False
             If true, ignores Parameters with stale gradient (gradient that has not
             been updated by `backward` after last step) and skip update.
 =cut
 
 method step(Int $batch_size, Bool $ignore_stale_grad=0)
 {
-    if(not $self->_kv_initialized)
-    {
-        $self->_init_kvstore;
-    }
+    $self->_init_kvstore() unless $self->_kv_initialized;
+    $self->_init_params() if scalar(@{ $self->_params_to_init });
     $self->_optimizer->rescale_grad($self->_scale/$batch_size);
-    enumerate(sub {
-        my ($i, $param) = @_;
-        return if $param->grad_req eq 'null';
-        if(not $ignore_stale_grad)
+    $self->_allreduce_grads();
+    $self->_update($ignore_stale_grad);
+}
+
+=head2 allreduce_grads
+
+        For each parameter, reduce the gradients from different contexts.
+
+        Should be called after `autograd->backward()`, outside of `record()` scope,
+        and before `trainer->update()`.
+
+        For normal parameter updates, `step()` should be used, which internally calls
+        `allreduce_grads()` and then `update()`. However, if you need to get the reduced
+        gradients to perform certain transformation, such as in gradient clipping, then
+        you may want to manually call `allreduce_grads()` and `update()` separately.
+=cut
+
+method allreduce_grads()
+{
+    $self->_init_kvstore() unless $self->_kv_initialized;
+    $self->_init_params() if scalar(@{ $self->_params_to_init });
+    assert(
+        (not ($self->kvstore and $self->update_on_kvstore)),
+        'allreduce_grads() when parameters are updated on kvstore '.
+        'is not supported. Try setting `update_on_kvstore` '.
+        'to False when creating trainer.'
+    );
+    $self->_allreduce_grads();
+}
+
+method _allreduce_grads()
+{
+    if($self->kvstore)
+    {
+        for(enumerate($self->params))
         {
-            for my $data (@{ $param->list_data })
+            my ($i, $param) = @$_;
+            if($param->grad_req ne 'null')
             {
-                if(not $data->_fresh_grad)
+                $self->kvstore->push($i, $param->list_grad(), priority=>-$i);
+                if(not $self->update_on_kvstore)
                 {
-                    AI::MXNet::Logging->warning(
-                        "Gradient of Parameter `%s` on context %s has not been updated ".
-                        "by backward since last `step`. This could mean a bug in your ".
-                        "model that maked it only use a subset of the Parameters (Blocks) ".
-                        "for this iteration. If you are intentionally only using a subset, ".
-                        "call step with ignore_stale_grad=True to suppress this ".
-                        "warning and skip updating of Parameters with stale gradient",
-                        $param->name, $data->context
-                    );
+                    $self->kvstore->pull($i, out => $param->list_grad(), priority=>-$i);
                 }
             }
         }
-        if($self->_kv_store)
-        {
-            $self->_kv_store->push($i, $param->list_grad, priority => -$i);
-            if($self->_update_on_kvstore)
-            {
-                $self->_kv_store->pull($i, out => $param->list_data, priority => -$i);
-                return;
-            }
-            else
-            {
-                $self->_kv_store->pull($i, out => $param->list_grad, priority => -$i);
-            }
-        }
-        for(zip($self->_updaters, $param->list_data, $param->list_grad)) {
-            my ($upd, $arr, $grad) = @$_;
-            if(not $ignore_stale_grad or $arr->_fresh_grad)
-            {
-                $upd->($i, $grad, $arr);
-                $arr->_fresh_grad(0);
-            }
-        }
-    }, $self->_params);
+    }
 }
 
 method learning_rate(Maybe [Num] $lr)
@@ -277,6 +400,91 @@ method set_learning_rate(Num $lr)
     $self->learning_rate($lr);
 }
 
+=head2 update
+
+        Makes one step of parameter update.
+
+        Should be called after autograd->backward() and outside of record() scope,
+        and after trainer->allreduce_grads().
+
+
+        For normal parameter updates, step() should be used, which internally calls
+        allreduce_grads() and then update(). However, if you need to get the reduced
+        gradients to perform certain transformation, such as in gradient clipping, then
+        you may want to manually call allreduce_grads() and update() separately.
+
+        Parameters
+        ----------
+        $batch_size : Int
+            Batch size of data processed. Gradient will be normalized by `1/$batch_size`.
+            Set this to 1 if you normalized loss manually with $loss = mean($loss).
+        $ignore_stale_grad : Bool, optional, default=False
+            If true, ignores Parameters with stale gradient (gradient that has not
+            been updated by backward() after last step) and skip update.
+=cut
+
+method update(Int $batch_size, Bool $ignore_stale_grad=0)
+{
+    $self->_init_kvstore() unless $self->_kv_initialized;
+    $self->_init_params() if scalar(@{ $self->_params_to_init });
+    assert(
+        (not ($self->kvstore and $self->update_on_kvstore)),
+        'update() when parameters are updated on kvstore '.
+        'is not supported. Try setting `update_on_kvstore` '.
+        'to False when creating trainer.'
+    );
+    $self->_optimizer->rescale_grad($self->_scale/$batch_size);
+    $self->_update($ignore_stale_grad);
+}
+
+method _update(Bool $ignore_stale_grad=0)
+{
+    for(enumerate($self->params))
+    {
+        my ($i, $param) = @$_;
+        next if($param->grad_req eq 'null');
+        # Unless told otherwise, warn about every copy whose gradient backward() has not refreshed.
+        if(not $ignore_stale_grad)
+        {
+            for my $data (@{ $param->_check_and_get($param->_data, []) })
+            {
+                if(not $data->_fresh_grad)
+                {
+                    AI::MXNet::Logging->warning(
+                        "Gradient of Parameter '%s' on context %s has not been updated ".
+                        "by backward since last `step`. This could mean a bug in your ".
+                        "model that made it only use a subset of the Parameters (Blocks) ".
+                        "for this iteration. If you are intentionally only using a subset, ".
+                        "call step with ignore_stale_grad=True to suppress this ".
+                        "warning and skip updating of Parameters with stale gradient",
+                        $param->name, $data->context
+                    );
+                }
+            }
+        }
+        if($self->kvstore and $self->update_on_kvstore)
+        {
+            if($param->stype eq 'default')
+            {
+                # 'row_sparse' parameters are not pulled immediately - they're pulled
+                # in `SparseBlock.sparse_forward`
+                $self->kvstore->pull($i, out => $param->list_data(), priority=>-$i);
+            }
+            next;
+        }
+
+        for(zip($self->_updaters, $param->list_data(), $param->list_grad()))
+        {
+            my ($upd, $arr, $grad) = @$_;
+            if(not $ignore_stale_grad or $arr->_fresh_grad)
+            {
+                $upd->($i, $grad, $arr);
+                $arr->_fresh_grad(0);
+            }
+        }
+    }
+}
+
 =head2 save_states
 
         Saves trainer states (e.g. optimizer, momentum) to a file.
@@ -290,14 +498,17 @@ method set_learning_rate(Num $lr)
 method save_states(Str $fname)
 {
     assert(defined $self->_optimizer);
-    if($self->_update_on_kvstore)
+    $self->_init_kvstore() unless $self->_kv_initialized;
+    $self->_init_params() if scalar(@{ $self->_params_to_init });
+
+    if($self->update_on_kvstore)
     {
-        $self->_kv_store->save_optimizer_states($fname, dump_optimizer=>1);
+        $self->kvstore->save_optimizer_states($fname, dump_optimizer=>1);
     }
     else
     {
         open(F, ">$fname") or Carp::confess("can not open $fname: $1");
-        print F $self->_updaters->[0]->get_states(dump_optimizer=>1);
+        print F $self->_updaters->[0]->get_states(dump_optimizer => 1);
         close(F);
     }
 }
@@ -314,10 +525,14 @@ method save_states(Str $fname)
 
 method load_states(Str $fname)
 {
-    if($self->_update_on_kvstore)
+    $self->_init_kvstore() unless $self->_kv_initialized;
+    $self->_init_params() if scalar(@{ $self->_params_to_init });
+
+    if($self->update_on_kvstore)
     {
-        $self->_kv_store->load_optimizer_states($fname);
-        $self->_optimizer($self->_kv_store->_updater->optimizer);
+        $self->kvstore->load_optimizer_states($fname);
+        $self->_optimizer($self->kvstore->_updater->optimizer);
+        $self->_optimizer->param_dict({ map { $_->[0] => $_->[1] } enumerate($self->params) });
     }
     else
     {
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/Utils.pm b/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/Utils.pm
index 6acb662..393f4fc 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/Utils.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/Utils.pm
@@ -161,7 +161,7 @@ method split_and_load(
     {
         return [$data->as_in_context($ctx_list->[0])];
     }
-    my $slices = __PACKAGE__->split_data($data, scalar(@$ctx_list), $batch_axis, $even_split);
+    my $slices = __PACKAGE__->split_data($data, scalar(@$ctx_list), $batch_axis, even_split => $even_split);
     my @ret;
     for(zip($slices, $ctx_list)) {
         my ($i, $ctx) = @$_;
@@ -177,20 +177,31 @@ method split_and_load(
 
 method clip_global_norm(ArrayRef[AI::MXNet::NDArray] $arrays, Num $max_norm)
 {
+    my $_norm = sub { my ($array) = @_;
+        if($array->stype eq 'default')
+        {
+            my $x = $array->reshape([-1]);
+            return AI::MXNet::NDArray->dot($x, $x);
+        }
+        return $array->norm->square;
+    };
     assert(@$arrays > 0);
-    my $total_norm = 0;
-    for my $arr (@$arrays)
+    my $ctx = $arrays->[0]->context;
+    my $total_norm = AI::MXNet::NDArray->add_n(map { $_norm->($_)->as_in_context($ctx) } @$arrays);
+    $total_norm = $total_norm->sqrt->asscalar;
+    if($total_norm != $total_norm or $total_norm =~ /inf/i) # NaN is the only value not equal to itself
     {
-        $arr = $arr->reshape([-1]);
-        $total_norm += AI::MXNet::NDArray->dot($arr, $arr);
+        AI::MXNet::Logging->warning('nan or inf is detected. Clipping results will be undefined.');
     }
-    $total_norm = sqrt($total_norm->asscalar);
     my $scale = $max_norm / ($total_norm + 1e-8);
-    if($scale < 1)
+    if($scale < 1.0)
     {
-        $_ *= $scale for @{ $arrays };
+        for my $arr (@$arrays)
+        {
+            $arr *= $scale;
+        }
     }
-    return $total_norm
+    return $total_norm;
 }
 
 =head2 check_sha1
@@ -277,4 +288,28 @@ func download(Str $url, Maybe[Str] :$path=, Bool :$overwrite=0, Maybe[Str] :$sha
     return $fname
 }
 
+package AI::MXNet::Gluon::Utils::HookHandle;
+use Mouse;
+use AI::MXNet::Base;
+use Scalar::Util qw(refaddr);
+has [qw/_hooks_dict_ref/] => (is => 'rw', init_arg => undef, weak_ref => 1);
+has [qw/_id/]             => (is => 'rw', init_arg => undef);
+
+method attach(Hash::Ordered $hooks_dict, $hook)
+{
+    assert((not $self->_hooks_dict_ref), 'The same handle cannot be attached twice.');
+    $self->_id(refaddr($hook));
+    $hooks_dict->set($self->_id, $hook);
+    $self->_hooks_dict_ref($hooks_dict);
+}
+
+method detach()
+{
+    my $hooks_dict = $self->_hooks_dict_ref;
+    if($hooks_dict and $hooks_dict->exists($self->_id))
+    {
+        $hooks_dict->delete($self->_id);
+    }
+}
+
 1;
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/Module.pm b/perl-package/AI-MXNet/lib/AI/MXNet/Module.pm
index 231d63b..16c9a92 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/Module.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/Module.pm
@@ -38,6 +38,22 @@ use List::Util qw(max);
 use Data::Dumper ();
 use Mouse;
 
+func _create_sparse_kvstore(Maybe[Str|AI::MXNet::KVStore] $kvstore)
+{
+    # Sparse (row_sparse) parameters are always updated on the kvstore side,
+    # so update_on_kvstore is unconditionally true here.
+    my $update_on_kvstore = 1;
+
+    # Reuse an existing KVStore object as-is; otherwise treat the argument
+    # as a kvstore type name and create a new store.
+    my $kv = blessed($kvstore)
+        ? $kvstore
+        : AI::MXNet::KVStore->create($kvstore);
+
+    return ($kv, $update_on_kvstore);
+}
+
+
 func _create_kvstore(
     Maybe[Str|AI::MXNet::KVStore] $kvstore,
     Int                           $num_device,
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/NDArray.pm b/perl-package/AI-MXNet/lib/AI/MXNet/NDArray.pm
index c7dac63..3177a37 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/NDArray.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/NDArray.pm
@@ -80,7 +80,7 @@ method STORABLE_thaw($cloning, $buf, $writable)
 
 method split_array(@args)
 {
-     $self->shape->[0] > 1 ? $self->split(num_outputs => $self->shape->[0], squeeze_axis => 1, axis => 0) : [$self];
+     $self->shape->[0] > 1 ? $self->split(num_outputs => $self->shape->[0], squeeze_axis => @{ $self->shape } > 1 ? 1 : 0, axis => 0) : [$self];
 }
 
 method at(Index @indices)
@@ -368,37 +368,26 @@ method _at(Index $idx)
 
 =head2 reshape
 
-    Returns a reshaped NDArray that shares the memory with current one.
+    Returns a **view** of this array with a new shape without altering any data.
     One shape dimension can be -1. In this case, the value is inferred
     from the length of the array and remaining dimensions.
 
     Parameters
     ----------
-    new_shape : Shape
+    $new_shape : Shape
         new shape of NDArray
+    :$reverse : bool, default 0
+        If true then the special values are inferred from right to left.
 =cut
 
-method reshape(ArrayRef[Int] $new_shape)
+method reshape(ArrayRef[Int] $new_shape, Bool :$reverse=0)
 {
-    my $i = -1;
-    my @inferred = map { $i++; $_ == -1 ? ($i) : () } @$new_shape;
-    assert((@inferred <= 1), 'Only one dimension can be inferred.');
-    $i = -1;
-    my @keep = map { $i++; $_ == 0 ? ($i) : () } @$new_shape;
-    my $shape = $self->shape;
-    if(@keep)
-    {
-        @{$new_shape}[@keep] = @{$shape}[@keep];
-    }
-    if(@inferred)
-    {
-        $new_shape->[$inferred[0]] = product(@{ $shape })/product(map { abs($_) } @{ $new_shape });
-    }
     my $handle = check_call(
-                    AI::MXNetCAPI::NDArrayReshape(
+                    AI::MXNetCAPI::NDArrayReshape64(
                         $self->handle,
                         scalar(@$new_shape),
-                        $new_shape
+                        $new_shape,
+                        $reverse
                     )
     );
     return __PACKAGE__->_ndarray_cls($handle, $self->writable);
@@ -1297,6 +1286,42 @@ method load(Str $filename)
     }
 }
 
+=head2 load_frombuffer
+
+    Loads an array dictionary or list from a buffer
+
+    See more details in 'save'.
+
+    Parameters
+    ----------
+    buf : str
+        Binary string containing contents of a file.
+
+    Returns
+    -------
+    array ref of AI::MXNet::NDArray, AI::MXNet::NDArray::RowSparse or AI::MXNet::NDArray::CSR, or
+    hash ref of AI::MXNet::NDArray, AI::MXNet::NDArray::RowSparse or AI::MXNet::NDArray::CSR
+        Loaded data.
+=cut
+
+method load_frombuffer(Str $buf)
+{
+    my ($handles, $names) = check_call(AI::MXNetCAPI::NDArrayLoadFromBuffer($buf, length($buf)));
+    if (not @$names)
+    {
+        return [map { __PACKAGE__->_ndarray_cls($_) } @$handles];
+    }
+    else
+    {
+        my $n = @$names;
+        my $h = @$handles;
+        confess("Handles [$h] and names [$n] count mismatch") unless $h == $n;
+        my %ret;
+        @ret{ @$names } = map { __PACKAGE__->_ndarray_cls($_) } @$handles;
+        return \%ret;
+    }
+}
+
 =head2 save
 
     Save array ref of NDArray or hash of str->NDArray to a binary file.
@@ -1586,6 +1611,8 @@ method backward(Maybe[AI::MXNet::NDArray] :$out_grad=, Bool :$retain_graph=0, Bo
 
 method CachedOp(@args) { AI::MXNet::CachedOp->new(@args) }
 
+method histogram(@args) { __PACKAGE__->_histogram(@args%2 ? ('data', @args) : @args) }
+
 my $lvalue_methods = join "\n", map {"use attributes 'AI::MXNet::NDArray', \\&AI::MXNet::NDArray::$_, 'lvalue';"}
 qw/at slice aspdl asmpdl reshape copy sever T astype as_in_context copyto empty zero ones full
                        array/;
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/Optimizer.pm b/perl-package/AI-MXNet/lib/AI/MXNet/Optimizer.pm
index b8ad043..fd13164 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/Optimizer.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/Optimizer.pm
@@ -1314,7 +1314,7 @@ method update(
     my $lr = $self->_get_lr($index);
     my $wd = $self->_get_wd($index);
     $self->_update_count($index);
-    my $is_sparse = ($weight->stype eq 'row_sparse' and $grad->stype eq 'row_sparse') ? 1 : 0;
+    my $is_sparse = $grad->stype eq 'row_sparse' ? 1 : 0;
     my $history = $state;
     if($is_sparse)
     {
@@ -1950,7 +1950,20 @@ method set_states($states)
 
 method get_states(Bool $dump_optimizer=0)
 {
-    return freeze($dump_optimizer ? [$self->states, $self->optimizer] : $self->states);
+    if($dump_optimizer)
+    {
+        # param_dict refers to the trainer's Parameter objects and is not serialized:
+        # detach it while freezing and restore it even if freeze() dies.
+        my $param_dict = $self->optimizer->param_dict;
+        $self->optimizer->param_dict({});
+        my $freezed;
+        my $ok = eval { $freezed = freeze([$self->states, $self->optimizer]); 1 };
+        my $err = $@;
+        $self->optimizer->param_dict($param_dict);
+        die $err unless $ok;
+        return $freezed;
+    }
+    return freeze($self->states);
 }
 
 package AI::MXNet::Optimizer;
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/Random.pm b/perl-package/AI-MXNet/lib/AI/MXNet/Random.pm
index dcd765c..8a47a12 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/Random.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/Random.pm
@@ -43,19 +43,39 @@ use AI::MXNet::Function::Parameters;
 
     Parameters
     ----------
-    seed_state : int
-        The random number seed to set to all devices.
+    $seed_state : int
+        The random number seed.
 
+    :$ctx : [Str|AI::MXNet::Context]
+        The device context of the generator. The default Str is "all" which means seeding random
+        number generators of all devices.
     Notes
     -----
-    The random number generator of mxnet is by default device specific.
-    This means if you set the same seed, the random number sequence
-    generated from GPU0 can be different from CPU.
+    Random number generators in MXNet are device specific.
+    mx->random->seed($seed_state) sets the state of each generator using $seed_state and the
+    device id. Therefore, random numbers generated from different devices can be different
+    even if they are seeded using the same seed.
+
+    To produce identical random number sequences independent of the device id,
+    set the optional ctx argument.
+    For example mx->random->seed($seed_state, ctx => mx->gpu(0))
+    This produces the same sequence of random numbers independent
+    of the device id, but the sequence can be different on different kind of devices as MXNet's
+    random number generators for CPU and GPU use different algorithms.
 =cut
 
-method seed(Int $seed_state)
+method seed(Int $seed_state, Str|AI::MXNet::Context :$ctx='all')
 {
-    check_call(AI::MXNetCAPI::RandomSeed($seed_state));
+    if(not ref $ctx)
+    {
+        confess("ctx argument could be either string 'all' or AI::MXNet::Context")
+            unless $ctx eq 'all';
+        check_call(AI::MXNetCAPI::RandomSeed($seed_state));
+    }
+    else
+    {
+        check_call(AI::MXNetCAPI::RandomSeedContext($seed_state, $ctx->device_type_id, $ctx->device_id));
+    }
 }
 
 sub AUTOLOAD {
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/Symbol.pm b/perl-package/AI-MXNet/lib/AI/MXNet/Symbol.pm
index 59ecd9c..bccf483 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/Symbol.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/Symbol.pm
@@ -229,6 +229,16 @@ method hypot(AI::MXNet::Symbol|Num $other)
     );
 }
 
+method reshape(@args)
+{
+    if(@args%2)
+    {
+        unshift @args, 'shape';
+    }
+    return $self->SUPER::reshape(@args);
+}
+
+
 method deepcopy()
 {
     my $handle = check_call(AI::MXNetCAPI::SymbolCopy($self->handle));
@@ -1495,6 +1505,8 @@ sub  _ufunc_helper
     }
 }
 
+method histogram(@args) { __PACKAGE__->_histogram(@args%2 ? ('data', @args) : @args) }
+
 sub contrib { 'AI::MXNet::Contrib::Symbol' }
 sub random  { 'AI::MXNet::Symbol::Random' }
 sub sparse  { 'AI::MXNet::Symbol::Sparse' }
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/Visualization.pm b/perl-package/AI-MXNet/lib/AI/MXNet/Visualization.pm
index 19c38bc..4790560 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/Visualization.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/Visualization.pm
@@ -340,6 +340,18 @@ method plot_network(
         {
             $attr{fillcolor} = $cm[3];
         }
+        elsif($op eq 'Flatten')
+        {
+            $label = $op;
+        }
+        elsif($op eq 'Dropout')
+        {
+            $label = "$op ($node->{attrs}{p})";
+        }
+        elsif($op eq 'Reshape')
+        {
+            $label = "$op $node->{attrs}{shape}";
+        }
         elsif($op eq 'Activation' or $op eq 'LeakyReLU')
         {
             $label = "$op\n$node->{attrs}{act_type}";
diff --git a/perl-package/AI-MXNet/t/test_conv.t b/perl-package/AI-MXNet/t/test_conv.t
index fc7d4ae..19c302b 100644
--- a/perl-package/AI-MXNet/t/test_conv.t
+++ b/perl-package/AI-MXNet/t/test_conv.t
@@ -22,7 +22,7 @@ use AI::MXNet::TestUtils qw(GetMNIST_ubyte);
 use Test::More tests => 1;
 
 ## speed up the tests when gpu present
-my $gpu_present = (`perl -e 'use AI::MXNet qw(mx); print mx->nd->ones([1], ctx => mx->gpu(0))->asscalar' 2>/dev/null` eq '1');
+my $gpu_present = mx->context->num_gpus;
 
 # symbol net
 my $batch_size = 100;
diff --git a/perl-package/AI-MXNet/t/test_cuda_module.t b/perl-package/AI-MXNet/t/test_cuda_module.t
index c2ad05e..4576e76 100644
--- a/perl-package/AI-MXNet/t/test_cuda_module.t
+++ b/perl-package/AI-MXNet/t/test_cuda_module.t
@@ -19,7 +19,7 @@ use strict;
 use warnings;
 use AI::MXNet qw(mx);
 use Test::More tests => 3;
-my $gpu_present = (`perl -e 'use AI::MXNet qw(mx); print mx->nd->ones([1], ctx => mx->gpu(0))->asscalar' 2>/dev/null` eq '1');
+my $gpu_present = mx->context->num_gpus;
 
 sub test_cuda_rtc
 {
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/AutoLoad.pm b/perl-package/AI-MXNet/t/test_engine.t
similarity index 68%
copy from perl-package/AI-MXNet/lib/AI/MXNet/AutoLoad.pm
copy to perl-package/AI-MXNet/t/test_engine.t
index 3dc7a06..4cf5744 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/AutoLoad.pm
+++ b/perl-package/AI-MXNet/t/test_engine.t
@@ -15,21 +15,27 @@
 # specific language governing permissions and limitations
 # under the License.
 
-package AI::MXNet::AutoLoad;
 use strict;
 use warnings;
+use AI::MXNet qw(mx);
+use Test::More tests => 2;
 
-sub AUTOLOAD
+sub test_bulk
 {
-    my ($class) = @_;
-    my ($prefix, $real_class) = $class->config;
-    my ($name) = our $AUTOLOAD =~ /::(\w+)$/;
-    my $sub = "_${prefix}_$name";
-    {
-        no strict 'refs';
-        *{"$class::$name"} = sub { shift; $real_class->$sub(@_); };
-    }
-    goto $class->can($name);
+    # Run a chain of in-place NDArray updates inside mx->engine->bulk(10, ...),
+    # checking the value once after an explicit wait_to_read() inside the
+    # callback and once more after bulk() has returned.
+    my $arr;
+    my $workload = sub {
+        $arr = mx->nd->ones([10]);
+        $arr *= 2;
+        $arr += 1;
+        $arr->wait_to_read();
+        $arr += 1;
+        ok(($arr->aspdl == 4)->all);    # (1 * 2) + 1 + 1
+        $arr += 1 for 1 .. 100;
+    };
+    mx->engine->bulk(10, $workload);
+    ok(($arr->aspdl == 104)->all);       # 4 plus 100 increments
 }
 
-1;
+test_bulk();
\ No newline at end of file
diff --git a/perl-package/AI-MXNet/t/test_executor.t b/perl-package/AI-MXNet/t/test_executor.t
index 2131c79..0c01bd5 100644
--- a/perl-package/AI-MXNet/t/test_executor.t
+++ b/perl-package/AI-MXNet/t/test_executor.t
@@ -17,7 +17,7 @@
 
 use strict;
 use warnings;
-use Test::More tests => 2283;
+use Test::More tests => 2285;
 use AI::MXNet qw(mx);
 use AI::MXNet::TestUtils qw(reldiff pdl_maximum pdl_minimum);
 use PDL;
@@ -181,6 +181,12 @@ sub test_reshape
     # test base exec forward
     $exe->forward(0);
     ok(($new_exe->outputs->[0]->aspdl == 4)->all);
+    $new_exe = $exe->reshape({ x=>[6,4] }, allow_up_sizing=>1);
+    # data ndarray is not shared between exe and new_exe
+    $new_exe->arg_arrays->[0] .= 0;
+    ok(($exe->arg_arrays->[0]->aspdl == 1)->all);
+    # weight ndarray is shared between exe and new_exe
+    ok(($new_exe->arg_arrays->[1]->aspdl == 1)->all);
 }
 
 test_bind(0);
diff --git a/perl-package/AI-MXNet/t/test_gluon.t b/perl-package/AI-MXNet/t/test_gluon.t
index ff7e2a6..3212722 100644
--- a/perl-package/AI-MXNet/t/test_gluon.t
+++ b/perl-package/AI-MXNet/t/test_gluon.t
@@ -17,11 +17,11 @@
 
 use strict;
 use warnings;
-use Test::More tests => 119;
+use Test::More tests => 232;
 use AI::MXNet qw(mx);
 use AI::MXNet::Gluon qw(gluon);
 use AI::MXNet::Gluon::NN qw(nn);
-use AI::MXNet::TestUtils qw(almost_equal);
+use AI::MXNet::TestUtils qw(almost_equal dies_ok);
 use Scalar::Util qw(refaddr);
 use AI::MXNet::Base;
 
@@ -34,6 +34,8 @@ sub test_parameter
     ok($p->data(mx->cpu(1))->context eq mx->cpu(1));
     is_deeply($p->data(mx->cpu(0))->shape, [10, 10]);
     ok($p->var->name eq  'weight');
+    ok($p->grad(mx->cpu(0))->stype eq 'default');
+    ok($p->data(mx->cpu(0))->stype eq 'default');
 
     $p->reset_ctx(ctx=>[mx->cpu(1), mx->cpu(2)]);
     is_deeply($p->list_ctx, [mx->cpu(1), mx->cpu(2)]);
@@ -41,29 +43,187 @@ sub test_parameter
 
 test_parameter();
 
+sub test_invalid_parameter_stype
+{
+    # An unknown storage type must be rejected when the Parameter is built.
+    dies_ok(sub {
+        gluon->Parameter(
+            'weight',
+            shape => [10, 10],
+            stype => 'invalid',
+        );
+    });
+}
+
+test_invalid_parameter_stype();
+
+sub test_invalid_parameter_grad_stype
+{
+    # An unknown gradient storage type must be rejected as well.
+    dies_ok(sub {
+        gluon->Parameter(
+            'weight',
+            shape      => [10, 10],
+            grad_stype => 'invalid',
+        );
+    });
+}
+
+test_invalid_parameter_grad_stype();
+
+sub test_sparse_parameter
+{
+    # A row_sparse parameter (row_sparse gradient too) replicated on two CPUs.
+    my $param = gluon->Parameter(
+        'weight',
+        shape      => [10, 10],
+        stype      => 'row_sparse',
+        grad_stype => 'row_sparse',
+    );
+    $param->initialize(init => 'xavier', ctx => [mx->cpu(0), mx->cpu(1)]);
+    my $rows = mx->nd->arange(start => 0, stop => 10, ctx => mx->cpu(1));
+    ok(@{ $param->list_grad } == 2);
+
+    # Without an owning Trainer, row_sparse reads are expected to die.
+    dies_ok(sub { $param->list_row_sparse_data($rows) });
+    my $trainer = gluon->Trainer([$param], 'sgd');
+    ok(@{ $param->list_row_sparse_data($rows) } == 2);
+
+    my $retained = $param->row_sparse_data($rows);
+    ok($retained->context eq mx->cpu(1));
+    is_deeply($retained->shape, [10, 10]);
+    ok($retained->stype eq 'row_sparse');
+
+    # The symbolic variable carries the name and the sparse storage attribute.
+    my $var = $param->var;
+    ok($var->name eq 'weight');
+    ok($var->attr('__storage_type__') eq STORAGE_TYPE_STR_TO_ID->{row_sparse});
+    ok($param->grad(mx->cpu(0))->stype eq 'row_sparse');
+
+    $param->reset_ctx(ctx => [mx->cpu(1), mx->cpu(2)]);
+    is_deeply($param->list_ctx, [mx->cpu(1), mx->cpu(2)]);
+}
+
+test_sparse_parameter();
+
+sub test_parameter_invalid_access
+{
+    # Fresh two-CPU context list for each initialize() call.
+    my $two_cpus = sub { [mx->cpu(0), mx->cpu(1)] };
+
+    # Dense accessors are refused on a row_sparse parameter.
+    my $sparse = gluon->Parameter('weight', shape => [10, 10], stype => 'row_sparse', grad_stype => 'row_sparse');
+    $sparse->initialize(init => 'xavier', ctx => $two_cpus->());
+    dies_ok(sub { $sparse->data });
+    dies_ok(sub { $sparse->list_data });
+
+    # row_sparse accessors are refused on a dense parameter.
+    my $row_ids = mx->nd->arange(start => 0, stop => 10);
+    my $dense = gluon->Parameter('weight', shape => [10, 10]);
+    $dense->initialize(init => 'xavier', ctx => $two_cpus->());
+    dies_ok(sub { $dense->row_sparse_data($row_ids->copyto(mx->cpu(0))) });
+    dies_ok(sub { $dense->list_row_sparse_data($row_ids) });
+}
+
+test_parameter_invalid_access();
+
 sub test_paramdict
 {
-    my $params = gluon->ParameterDict('net_');
-    $params->get('weight', shape=>[10, 10]);
-    is_deeply([$params->keys], ['net_weight']);
-    $params->initialize(ctx=>mx->cpu());
-    $params->save('test.params');
-    $params->load('test.params', ctx => mx->cpu());
+    my $ctx  = mx->cpu(1);
+    my $file = 'test_paramdict.params';
+    # Declares net_w0 (dense) and net_w1; w1 gets storage type $w1_stype
+    # when one is given and the default otherwise.
+    my $declare = sub {
+        my ($w1_stype) = @_;
+        my $dict = gluon->ParameterDict('net_');
+        $dict->get('w0', shape => [10, 10]);
+        $dict->get('w1', shape => [10, 10], (defined $w1_stype ? (stype => $w1_stype) : ()));
+        return $dict;
+    };
+
+    my $saved = $declare->('row_sparse');
+    my $all_row_ids = mx->nd->arange(start => 0, stop => 10, ctx => $ctx);
+    is_deeply([$saved->keys()], ['net_w0', 'net_w1']);
+    $saved->initialize(ctx => $ctx);
+    # A Trainer must own the dict before row_sparse_data can be read.
+    my $trainer0 = gluon->Trainer($saved, 'sgd');
+    my $prev_w0 = $saved->get('w0')->data($ctx);
+    my $prev_w1 = $saved->get('w1')->row_sparse_data($all_row_ids);
+    $saved->save($file);
+
+    # Reload into an identically declared (dense + row_sparse) dict.
+    my $sparse_copy = $declare->('row_sparse');
+    $sparse_copy->load($file, ctx => $ctx);
+    my $trainer1 = gluon->Trainer($sparse_copy, 'sgd');
+    my ($cur_w0, $cur_w1) = (
+        $sparse_copy->get('w0')->data($ctx),
+        $sparse_copy->get('w1')->row_sparse_data($all_row_ids),
+    );
+    ok(almost_equal($prev_w0->aspdl, $cur_w0->aspdl));
+    ok(almost_equal($prev_w1->aspdl, $cur_w1->aspdl));
+
+    # Reload the same checkpoint into a dict where both parameters are dense.
+    my $dense_copy = $declare->();
+    $dense_copy->load($file, ctx => $ctx);
+    ($cur_w0, $cur_w1) = ($dense_copy->get('w0')->data($ctx), $dense_copy->get('w1')->data($ctx));
+    ok(almost_equal($prev_w0->aspdl, $cur_w0->aspdl));
+    ok(almost_equal($prev_w1->aspdl, $cur_w1->aspdl));
 }
 
 test_paramdict();
 
+sub test_parameter_row_sparse_data
+{
+    # A 4x2 row_sparse parameter replicated on two devices.
+    my ($dev_a, $dev_b) = (mx->cpu(1), mx->cpu(2));
+    my $rows_total = 4;
+    my $param = gluon->Parameter('x', shape => [$rows_total, 2], stype => 'row_sparse');
+    $param->initialize(init => 'xavier', ctx => [$dev_a, $dev_b]);
+    # row_sparse_data needs an owning Trainer.
+    my $trainer = gluon->Trainer([$param], 'sgd');
+    my $snapshot = $param->_data->[0]->copy();
+    is($snapshot->stype, 'row_sparse');
+
+    # A subset of rows, requested on the first device.
+    my $ids_a  = mx->nd->array([0, 1], ctx => $dev_a);
+    my $got_a  = $param->row_sparse_data($ids_a);
+    my $want_a = mx->nd->sparse->retain($snapshot, $ids_a->as_in_context($dev_a));
+    ok(almost_equal($got_a->aspdl, $want_a->aspdl));
+    is($got_a->context, $dev_a);
+
+    # Every row, requested on the second device, matches the whole snapshot.
+    my $ids_b = mx->nd->arange(start => 0, stop => $rows_total, ctx => $dev_b);
+    my $got_b = $param->row_sparse_data($ids_b);
+    ok(almost_equal($got_b->aspdl, $snapshot->aspdl));
+    is($got_b->context, $dev_b);
+
+    # list_row_sparse_data returns an array ref; its first entry is checked.
+    my $ids_c    = mx->nd->array([0, 1, 2]);
+    my $got_list = $param->list_row_sparse_data($ids_c);
+    my $want_c   = mx->nd->sparse->retain($snapshot, $ids_c->as_in_context($dev_a));
+    ok(almost_equal($got_list->[0]->aspdl, $want_c->aspdl));
+}
+
+test_parameter_row_sparse_data();
+
+sub test_constant
+{
+    # HybridBlock holding a constant built from a 2x2 PDL; forward adds it to the input.
+    package Test {
+        use AI::MXNet::Gluon::Mouse;
+        extends 'AI::MXNet::Gluon::HybridBlock';
+        sub BUILD
+        {
+            my ($self) = @_;
+            my $initial = mx->nd->array([[1,2], [3,4]])->aspdl;
+            $self->value($initial);
+            $self->const($self->params->get_constant('const', $self->value));
+        }
+        sub hybrid_forward
+        {
+            # Arguments: self, F, input, then the constant as a name/value pair.
+            my (undef, undef, $input, undef, $constant) = @_;
+            return $input + $constant;
+        }
+    };
+
+    my $block = Test->new();
+    $block->initialize();
+    my $trainer = gluon->Trainer(
+        $block->collect_params(), 'sgd',
+        { learning_rate => 1.0, momentum => 0.5 }
+    );
+
+    my ($input, $out);
+    mx->autograd->record(sub {
+        $input = mx->nd->ones([2,2]);
+        $input->attach_grad();
+        $out = $block->($input);
+        $out->backward();
+    });
+
+    $trainer->step(1);
+
+    # The optimizer step must leave the constant untouched, while the
+    # input still receives a gradient of ones.
+    ok(($block->const->data->aspdl == $block->value)->all);
+    ok(($input->grad->aspdl == 1)->all);
+}
+
+test_constant();
+
 package Net;
 use AI::MXNet::Gluon::Mouse;
 use AI::MXNet::Function::Parameters;
 extends 'AI::MXNet::Gluon::Block';
+# Width passed as in_units to both Dense layers. The default 0 is forwarded
+# unchanged; presumably nn->Dense then infers the width on first use — confirm.
+has 'in_units' => (is => 'rw', default => 0);
 
+# Creates dense0 and dense1 (5 units each) inside this block's name scope.
 sub BUILD
 {
     my $self = shift;
     $self->name_scope(sub {
-        $self->dense0(nn->Dense(5, in_units=>5));
-        $self->dense1(nn->Dense(5, in_units=>5));
+        $self->dense0(nn->Dense(5, in_units=>$self->in_units));
+        $self->dense1(nn->Dense(5, in_units=>$self->in_units));
     });
 }
 
@@ -76,17 +236,71 @@ package main;
 
 sub test_parameter_sharing
 {
+    # net2 is built on net1's collect_params(); the shared parameters are
+    # initialized via net1, run through net2, saved, and loaded into net3.
-    my $net1 = Net->new(prefix=>'net1_');
+    my $net1 = Net->new(prefix=>'net1_', in_units => 5);
     my $net2 = Net->new(prefix=>'net2_', params=>$net1->collect_params());
     $net1->collect_params()->initialize();
     $net2->(mx->nd->zeros([3, 5]));
-    $net1->save_params('net1.params');
+    $net1->save_parameters('net1.params');
     my $net3 = Net->new(prefix=>'net3_');
-    $net3->load_params('net1.params', ctx => mx->cpu());
+    $net3->load_parameters('net1.params', ctx => mx->cpu());
+    # Same round trip, but here only the sharing net (net5) is given in_units.
+    my $net4 = Net->new(prefix=>'net4_');
+    my $net5 = Net->new(prefix=>'net5_', in_units=>5, params=>$net4->collect_params());
+    $net4->collect_params()->initialize();
+    $net5->(mx->nd->zeros([3, 5]));
+    $net4->save_parameters('net4.params');
+    my $net6 = Net->new(prefix=>'net6_');
+    $net6->load_parameters('net4.params', ctx => mx->cpu());
 }
 
 test_parameter_sharing();
 
+sub test_parameter_str
+{
+    # Block with a single bias-free Dense(10) over 5 inputs.
+    package Net1 {
+        use AI::MXNet::Gluon::Mouse;
+        extends 'AI::MXNet::Gluon::Block';
+        sub BUILD
+        {
+            my ($self) = @_;
+            $self->name_scope(sub {
+                $self->dense0(nn->Dense(10, in_units => 5, use_bias => 0));
+            });
+        }
+    };
+    # Expected stringified ParameterDict: header line, the weight line, closing paren.
+    my ($header, $weight_line, $footer) = split /\n/, Net1->new(prefix => 'net1_')->collect_params();
+    ok($header eq 'net1_ (');
+    ok($weight_line =~ /net1_dense0_weight/);
+    ok($weight_line =~ /\(10, 5\)/);
+    ok($weight_line =~ /float32/);
+    ok($footer eq ')');
+}
+
+test_parameter_str();
+
+sub test_collect_parameters
+{
+    my $net = nn->HybridSequential(prefix=>"test_");
+    $net->name_scope(sub {
+        $net->add(nn->Conv2D(10, 3));
+        $net->add(nn->Dense(10, activation=>'relu'));
+    });
+    # Each case: [arguments for collect_params (selection regex), expected keys in order].
+    my @cases = (
+        [[], ['test_conv0_weight', 'test_conv0_bias', 'test_dense0_weight', 'test_dense0_bias']],
+        [['.*weight'], ['test_conv0_weight', 'test_dense0_weight']],
+        [['test_conv0_bias|test_dense0_bias'], ['test_conv0_bias', 'test_dense0_bias']],
+    );
+    for my $case (@cases)
+    {
+        my ($select, $expected) = @$case;
+        is_deeply([$net->collect_params(@$select)->keys], $expected);
+    }
+}
+
+test_collect_parameters();
+
 sub test_basic
 {
     my $model = nn->Sequential();
@@ -165,7 +379,7 @@ sub test_symbol_block
     my $outputs = $model->($inputs)->get_internals();
     my $smodel = gluon->SymbolBlock($outputs, $inputs, params=>$model->collect_params);
 
-    ok(@{ $smodel->(mx->nd->zeros([16, 10])) } == 14);
+    ok($smodel->(mx->nd->zeros([16, 10])) == 14);
     my $out = $smodel->(mx->sym->var('in'));
     ok(@{ $out } == @{ $outputs->list_outputs() });
 
@@ -183,6 +397,32 @@ sub test_symbol_block
 
 test_symbol_block();
 
+sub test_sparse_symbol_block
+{
+    # out = dot(data, weight) + bias, where weight is declared row_sparse.
+    my ($data, $weight, $bias) = (
+        mx->sym->var('data'),
+        mx->sym->var('weight', stype => 'row_sparse'),
+        mx->sym->var('bias'),
+    );
+    my $out = mx->sym->broadcast_add(mx->sym->dot($data, $weight), $bias);
+    # Building a SymbolBlock over a graph with a sparse parameter must die.
+    dies_ok(sub { gluon->SymbolBlock($out, $data) });
+}
+
+test_sparse_symbol_block();
+
+sub test_sparse_hybrid_block0
+{
+    # Options shared by both declarations, kept in their original order.
+    my @common = (dtype => 'float32', allow_deferred_init => 1);
+    my $params = gluon->ParameterDict('net_');
+    $params->get('weight', shape => [5,5], stype => 'row_sparse', @common);
+    $params->get('bias', shape => [5], @common);
+    my $dense = nn->Dense(5, params => $params);
+    $dense->initialize();
+    my $input = mx->nd->ones([2,5]);
+    # Forwarding a HybridBlock that owns a row_sparse parameter must die.
+    dies_ok(sub { $dense->($input) });
+}
+
+test_sparse_hybrid_block0();
+
 sub check_layer_forward
 {
     my ($layer, $dshape) = @_;
@@ -314,6 +554,7 @@ sub test_pool
         nn->MaxPool1D(3),
         nn->MaxPool1D(3, 2),
         nn->AvgPool1D(),
+        nn->AvgPool1D(count_include_pad=>0),
         nn->GlobalAvgPool1D(),
     );
     for my $layer (@layers1d)
@@ -326,6 +567,7 @@ sub test_pool
         nn->MaxPool2D([3, 3]),
         nn->MaxPool2D(3, 2),
         nn->AvgPool2D(),
+        nn->AvgPool2D(count_include_pad=>0),
         nn->GlobalAvgPool2D(),
     );
     for my $layer (@layers2d)
@@ -338,6 +580,7 @@ sub test_pool
         nn->MaxPool3D([3, 3, 3]),
         nn->MaxPool3D(3, 2),
         nn->AvgPool3D(),
+        nn->AvgPool3D(count_include_pad=>0),
         nn->GlobalAvgPool3D(),
     );
     for my $layer (@layers3d)
@@ -367,6 +610,30 @@ sub test_batchnorm
 
 test_batchnorm();
 
+sub test_instancenorm
+{
+    # Run InstanceNorm(in_channels => 10) through check_layer_forward on a [2, 10, 10, 10] input.
+    check_layer_forward(nn->InstanceNorm(in_channels => 10), [2, 10, 10, 10]);
+}
+
+test_instancenorm();
+
+sub test_layernorm
+{
+    # Run LayerNorm(in_channels => 10) through check_layer_forward on a [2, 10, 10, 10] input.
+    check_layer_forward(nn->LayerNorm(in_channels => 10), [2, 10, 10, 10]);
+}
+
+test_layernorm();
+
+sub test_reflectionpad
+{
+    # Run ReflectionPad2D(3) through check_layer_forward on a [2, 3, 24, 24] input.
+    check_layer_forward(nn->ReflectionPad2D(3), [2, 3, 24, 24]);
+}
+
+test_reflectionpad();
+
 sub test_reshape
 {
     my $x = mx->nd->ones([2, 4, 10, 10]);
@@ -460,71 +727,6 @@ sub test_flatten
 
 test_flatten();
 
-sub test_trainer
-{
-    my $dict_equ = sub { my ($a, $b) = @_;
-        is_deeply({ map { $_ => 1 } keys %$a }, { map { $_ => 1 } keys %$b });
-        for my $k (keys %$a)
-        {
-            ok(($a->{$k}->aspdl == $b->{$k}->aspdl)->all);
-        }
-    };
-    my $x = gluon->Parameter('x', shape=>[10]);
-    $x->initialize(ctx=>[mx->cpu(0), mx->cpu(1)], init=>'zeros');
-    my $trainer = gluon->Trainer([$x], 'sgd', {'learning_rate'=> 1.0, 'momentum'=> 0.5});
-    my $y;
-    mx->autograd->record(sub {
-        for my $w (@{ $x->list_data() })
-        {
-            $y = $w + 1;
-            $y->backward();
-        }
-    });
-    $trainer->step(1);
-
-    ok(($x->data(mx->cpu(1))->aspdl == -2)->all);
-
-    $x->lr_mult(0.5);
-
-    mx->autograd->record(sub {
-        for my $w (@{ $x->list_data() })
-        {
-            $y = $w + 1;
-            $y->backward();
-        }
-    });
-    $trainer->step(1);
-
-    ok(($x->data(mx->cpu(1))->aspdl == -4)->all);
-
-    $trainer->save_states('test.states');
-    my $states;
-    if($trainer->_update_on_kvstore)
-    {
-        $states = { %{ $trainer->_kv_store->_updater->states } };
-    }
-    else
-    {
-        $states = { %{ $trainer->_updaters->[0]->states } };
-    }
-    $trainer->load_states('test.states');
-    if($trainer->_update_on_kvstore)
-    {
-        $dict_equ->($trainer->_kv_store->_updater->states, $states);
-        ok($trainer->_optimizer eq $trainer->_kv_store->_updater->optimizer);
-    }
-    else
-    {
-        for my $updater (@{ $trainer->_updaters })
-        {
-            $dict_equ->($updater->states, $states);
-        }
-        ok($trainer->_optimizer eq $trainer->_updaters->[0]->optimizer);
-    }
-}
-
-test_trainer();
-
 sub test_block_attr_hidden
 {
     my $b = gluon->Block();
@@ -565,23 +767,554 @@ sub test_block_attr_regular
     $b->c(gluon->Block());
     my $c2 = gluon->Block();
     $b->c($c2);
-    ok(refaddr($b->c) == refaddr($c2) and refaddr($b->_children->[0]) == refaddr($c2));
+    ok(refaddr($b->c) == refaddr($c2) and refaddr(($b->_children->values)[0]) == refaddr($c2));
 }
 
 test_block_attr_regular();
 
+sub test_block_attr_list_of_block
+{
+    # Model1: blocks in a plain array ref.
+    package Model1 {
+        use AI::MXNet::Gluon::Mouse;
+        extends 'AI::MXNet::Gluon::Block';
+        sub BUILD
+        {
+            my ($self) = @_;
+            $self->name_scope(sub {
+                $self->layers([map { nn->Dense($_ * 10) } 0..5]);
+            });
+        }
+    };
+    # Model2: blocks in an array ref nested inside a hash ref.
+    package Model2 {
+        use AI::MXNet::Gluon::Mouse;
+        extends 'AI::MXNet::Gluon::Block';
+        sub BUILD
+        {
+            my ($self) = @_;
+            $self->name_scope(sub {
+                $self->layers({});
+                $self->layers->{a} = [map { nn->Dense($_ * 10) } 0..5];
+            });
+        }
+    };
+    # Model3: blocks held by a Sequential container.
+    package Model3 {
+        use AI::MXNet::Gluon::Mouse;
+        extends 'AI::MXNet::Gluon::Block';
+        sub BUILD
+        {
+            my ($self) = @_;
+            $self->name_scope(sub {
+                $self->layers(nn->Sequential());
+                $self->layers->add(map { nn->Dense($_ * 10) } 0..5);
+            });
+        }
+    };
+    # Model4: an attribute holding plain, non-block data.
+    package Model4 {
+        use AI::MXNet::Gluon::Mouse;
+        extends 'AI::MXNet::Gluon::Block';
+        sub BUILD
+        {
+            my ($self) = @_;
+            $self->name_scope(sub {
+                $self->data({a => '4', b => 123});
+            });
+        }
+    };
+
+    # collect_params is expected to warn only when blocks hide in plain containers.
+    my $warnings = 0;
+    local $SIG{__WARN__} = sub { $warnings++ };
+    for my $case ([Model1 => 1], [Model2 => 1], [Model3 => 0], [Model4 => 0])
+    {
+        my ($class, $expect_warning) = @$case;
+        $warnings = 0;
+        $class->new->collect_params;
+        ok($expect_warning ? $warnings > 0 : $warnings == 0);
+    }
+}
+
+test_block_attr_list_of_block();
+
+sub check_sequential
+{
+    my ($net) = @_;
+    # Append three Dense(10) layers, remembering each for identity checks.
+    my @added;
+    for (1 .. 3)
+    {
+        my $layer = nn->Dense(10);
+        $net->add($layer);
+        push @added, $layer;
+    }
+
+    # Index access, negative index, and slice() must all return the same objects.
+    ok(refaddr($net->[1]) == refaddr($added[1]));
+    ok(refaddr($net->[-1]) == refaddr($added[2]));
+    my $slc = $net->slice([1,2]);
+    ok(@$slc == 2 and refaddr($slc->[0]) == refaddr($added[1]) and refaddr($slc->[1]) == refaddr($added[2]));
+    # slice() keeps the container's class.
+    ok(ref $slc eq ref $net);
+}
+
+sub test_sequential
+{
+    # Each container is created right before it is checked.
+    for my $make (sub { nn->Sequential() }, sub { nn->HybridSequential() })
+    {
+        check_sequential($make->());
+    }
+}
+
+test_sequential();
+
+sub test_global_norm_clip
+{
+    my $check_global_norm_clip = sub {
+        my ($stype) = @_;
+        # ones(3x3) and ones(4x4) give a global norm of sqrt(9 + 16) = 5.
+        my ($small, $large) = (
+            mx->nd->ones([3,3])->tostype($stype),
+            mx->nd->ones([4,4])->tostype($stype),
+        );
+        my $norm = gluon->utils->clip_global_norm([$small, $large], 1.0);
+        ok($norm == 5);
+        # Clipping to a max norm of 1 scales every element by 1/5.
+        ok(almost_equal($small->aspdl, mx->nd->ones([3,3])->aspdl/5));
+        ok(almost_equal($large->aspdl, mx->nd->ones([4,4])->aspdl/5));
+
+        # A NaN among the inputs is expected to produce exactly one warning.
+        my $with_nan = mx->nd->array([1.0, 2.0, 'nan'])->tostype($stype);
+        my $warned = 0;
+        local $SIG{__WARN__} = sub { $warned++ };
+        gluon->utils->clip_global_norm([$small, $with_nan], 2.0);
+        ok($warned == 1);
+    };
+    for my $stype ('default', 'row_sparse')
+    {
+        $check_global_norm_clip->($stype);
+    }
+}
+
+test_global_norm_clip();
+
 sub test_embedding
 {
-    my $layer = gluon->nn->Embedding(10, 100);
-    $layer->initialize();
-    my $x = mx->nd->array([3,4,2,0,1]);
+    # Turn off verbose storage-fallback logging for the sparse-gradient runs.
+    local $ENV{MXNET_STORAGE_FALLBACK_LOG_VERBOSE} = 0;
+
+    # Indices 0..4 are each looked up once, so their weight rows get a
+    # gradient of 1 while rows 5..9 stay at 0.
+    my $check_embedding = sub {
+        my ($sparse_grad) = @_;
+        my $embed = nn->Embedding(10, 100, sparse_grad => $sparse_grad);
+        $embed->initialize();
+        my $indices = mx->nd->array([3,4,2,0,1]);
+        my $out;
+        mx->autograd->record(sub {
+            $out = $embed->($indices);
+            $out->backward();
+        });
+        my $grad = $embed->weight->grad->aspdl;
+        ok(($grad->slice('X', [0, 4]) == 1)->all);
+        ok(($grad->slice('X', [5, -1]) == 0)->all);
+    };
+
+    # An all-ones input of length 20481: every lookup hits row 1, so the
+    # summed weight gradient equals the input length.
+    my $check_embedding_large_input = sub {
+        my ($sparse_grad) = @_;
+        my $embed = nn->Embedding(10, 1, sparse_grad => $sparse_grad);
+        $embed->initialize();
+        $embed->hybridize();
+        my $length = 20481;
+        my ($looked_up, $loss);
+        mx->autograd->record(sub {
+            $looked_up = $embed->(mx->nd->ones([$length]));
+            $loss = $looked_up->sum;
+        });
+        $loss->backward;
+        ok($embed->weight->grad->sum->asscalar == $length);
+    };
+
+    for my $sparse_grad (1, 0)
+    {
+        $check_embedding->($sparse_grad);
+    }
+    for my $sparse_grad (1, 0)
+    {
+        $check_embedding_large_input->($sparse_grad);
+    }
+}
+
+test_embedding();
+
+sub test_hybrid_stale_cache
+{
+    # Dense(10) with zero weights and unit bias; $flatten selects its flatten mode.
+    my $dense = sub {
+        my ($flatten) = @_;
+        return nn->Dense(10, weight_initializer=>'zeros', bias_initializer=>'ones', flatten=>$flatten);
+    };
+
+    # Case 1: append a Flatten layer after the hybridized net has already run.
+    my $net = nn->HybridSequential();
+    $net->name_scope(sub { $net->add($dense->(0)) });
+    $net->hybridize();
+    $net->initialize();
+    $net->(mx->nd->ones([2,3,5]));
+    $net->add(nn->Flatten());
+    is_deeply($net->(mx->nd->ones([2,3,5]))->shape, [2, 30]);
+
+    # Case 2: replace the named child fc2 after the first forward pass.
+    $net = nn->HybridSequential();
+    $net->name_scope(sub {
+        $net->fc1($dense->(0));
+        $net->fc2($dense->(0));
+    });
+    $net->hybridize();
+    $net->initialize();
+    $net->(mx->nd->ones([2,3,5]));
+    $net->fc2($dense->(1));
+    $net->initialize();
+    is_deeply($net->(mx->nd->ones([2,3,5]))->shape, [2, 10]);
+}
+
+test_hybrid_stale_cache();
+
+sub test_lambda
+{
+    # Three ways to express tanh followed by LeakyReLU with slope 0.1.
+    my $reference = nn->HybridSequential();
+    $reference->add(nn->Activation('tanh'), nn->LeakyReLU(0.1));
+
+    my $hybrid_lambdas = nn->HybridSequential();
+    $hybrid_lambdas->add(
+        nn->HybridLambda('tanh'),
+        nn->HybridLambda(sub {
+            my ($F, $x, @rest) = @_;
+            $F->LeakyReLU($x, @rest, slope=>0.1);
+        }),
+    );
+
+    my $imperative_lambdas = nn->Sequential();
+    $imperative_lambdas->add(
+        nn->Lambda('tanh'),
+        nn->Lambda(sub { mx->nd->LeakyReLU($_[0], slope=>0.1); }),
+    );
+
+    my $input = mx->nd->random->uniform(shape=>[2, 3, 5, 7]);
+    my @outputs;
+    for my $variant ($reference, $hybrid_lambdas, $imperative_lambdas)
+    {
+        push @outputs, $variant->($input);
+    }
+    ok(almost_equal($outputs[0]->aspdl, $outputs[1]->aspdl, 1e-3));
+    ok(almost_equal($outputs[0]->aspdl, $outputs[2]->aspdl, 1e-3));
+}
+
+test_lambda();
+
+sub test_fill_shape_deferred
+{
+    my $net = nn->HybridSequential();
+    $net->name_scope(sub {
+        $net->add(
+            nn->Conv2D(64, kernel_size=>2, padding=>1),
+            nn->BatchNorm(),
+            nn->Dense(10),
+        );
+    });
+    $net->hybridize();
+    $net->initialize();
+    # Input-dependent shapes left open at construction are filled by the first forward.
+    $net->(mx->nd->ones([2,3,5,7]));
+    my ($conv, $norm, $dense) = ($net->[0], $net->[1], $net->[2]);
+    ok($conv->weight->shape->[1] == 3);      # input channels
+    ok($norm->gamma->shape->[0] == 64);      # conv output channels
+    ok($dense->weight->shape->[1] == 3072);  # 64 * 6 * 8 after the padded 2x2 conv
+}
+
+test_fill_shape_deferred();
+
+sub test_fill_shape_load
+{
+    my $ctx = mx->context->current_context();
+    # Conv2D(64) -> BatchNorm -> Dense(10), hybridized, with input shapes left open.
+    my $build = sub {
+        my $net = nn->HybridSequential();
+        $net->name_scope(sub {
+            $net->add(nn->Conv2D(64, kernel_size=>2, padding=>1),
+                      nn->BatchNorm(),
+                      nn->Dense(10))
+        });
+        $net->hybridize();
+        return $net;
+    };
+
+    # Run the first network once so its shapes are known, then save it.
+    my $source = $build->();
+    $source->initialize(mx->init->Uniform, ctx => $ctx);
+    $source->(mx->nd->ones([2,3,5,7], ctx => $ctx));
+    $source->save_parameters('net_fill.params');
+
+    # A second network that was never run forward picks the shapes up from the file.
+    my $restored = $build->();
+    $restored->initialize();
+    $restored->load_parameters('net_fill.params', ctx=>$ctx);
+    ok($restored->[0]->weight->shape->[1] == 3);
+    ok($restored->[1]->gamma->shape->[0] == 64);
+    ok($restored->[2]->weight->shape->[1] == 3072);
+}
+
+test_fill_shape_load();
+
+use JSON::PP qw(decode_json);
+
+sub test_inline
+{
     my $y;
+
+    # Three stacked Dense(10) layers; the recorded graph is compared for
+    # hybridize(inline_limit=>3) versus hybridize(inline_limit=>0).
+    my $net = nn->HybridSequential();
+    $net->name_scope(sub {
+        $net->add(nn->Dense(10));
+        $net->add(nn->Dense(10));
+        $net->add(nn->Dense(10));
+    });
+    $net->initialize();
+
+    $net->hybridize(inline_limit=>3);
     mx->autograd->record(sub {
-        $y = $layer->($x);
-        $y->backward();
+        $y = $net->(mx->nd->zeros([1,10]));
+    });
+    # Node count of the JSON serialization of the recorded symbol.
+    my $len_1 = @{ decode_json(mx->autograd->get_symbol($y)->tojson())->{nodes} };
+    $y->backward();
+
+    $net->hybridize(inline_limit=>0);
+    mx->autograd->record(sub {
+        $y = $net->(mx->nd->zeros([1,10]));
     });
-    ok(($layer->weight->grad->slice([0,4]) == 1)->aspdl->all);
-    ok(($layer->weight->grad->slice([5, -1]) == 0)->aspdl->all);
+    my $len_2 = @{ decode_json(mx->autograd->get_symbol($y)->tojson())->{nodes} };
+    $y->backward();
+
+    # The test expects the inline_limit=>3 graph to have exactly two more nodes.
+    is($len_1, $len_2 + 2);
 }
 
-test_embedding();
+test_inline();
+
+sub test_activations
+{
+    # Probe values alternate around zero so both branches of each
+    # piecewise activation are exercised.
+    my $point_to_validate = mx->nd->array([(-0.1, 0.1) x 3]);
+
+    # NOTE(review): each ok($test_point == $ref_point) below receives the NDArray
+    # produced by the overloaded ==, not a Perl boolean; confirm that this
+    # actually checks the values rather than only the object's truthiness.
+
+    # swish(x) = x * sigmoid(x)
+    my $swish = nn->Swish();
+    my $swish_test = sub { my ($x) = @_;
+        return $x * mx->nd->sigmoid($x);
+    };
+
+    for(zip($swish_test->($point_to_validate), $swish->($point_to_validate)))
+    {
+        my ($test_point, $ref_point) = @$_;
+        ok($test_point == $ref_point);
+    }
+
+    # elu(x) = exp(x) - 1 for x < 0, x otherwise (alpha = 1).
+    my $elu = nn->ELU();
+    my $elu_test = sub { my ($x) = @_;
+        my $elu = sub { my ($x) = @_;
+            # Branch on the element's scalar value rather than on the NDArray
+            # returned by the overloaded comparison operator.
+            return $x->asscalar < 0 ? 1.0 * (mx->nd->exp($x) - 1) : $x;
+        };
+        return [map { $elu->($_) } @{ $x }];
+    };
+
+    for(zip($elu_test->($point_to_validate), $elu->($point_to_validate)))
+    {
+        my ($test_point, $ref_point) = @$_;
+        ok($test_point == $ref_point);
+    }
+
+    # selu(x) = scale * x for x >= 0, scale * alpha * (exp(x) - 1) otherwise.
+    my $selu = nn->SELU();
+    my $selu_test = sub { my ($x) = @_;
+        my $selu = sub { my ($x) = @_;
+            my ($scale, $alpha) = (1.0507009873554804934193349852946, 1.6732632423543772848170429916717);
+            # Must be '>=': the fat comma '=>' turned this into a two-element
+            # list ($x, ...) and skipped the comparison entirely.
+            return $x->asscalar >= 0 ? $scale * $x : $scale * ($alpha * mx->nd->exp($x) - $alpha);
+        };
+        return [map { $selu->($_) } @{ $x }];
+    };
+
+    for(zip($selu_test->($point_to_validate), $selu->($point_to_validate)))
+    {
+        my ($test_point, $ref_point) = @$_;
+        ok($test_point == $ref_point);
+    }
+
+    # PReLU is expected to start with slope 0.25 for negative inputs.
+    my $prelu = nn->PReLU();
+    $prelu->initialize();
+    my $x = $point_to_validate->reshape([1, 3, 2]);
+    ok(almost_equal($prelu->($x)->aspdl, mx->nd->where($x >= 0, $x, 0.25 * $x)->aspdl));
+}
+
+test_activations();
+
+sub test_req
+{
+    my $data  = mx->nd->random->uniform(shape=>[1,3,224,224]);
+    my $label = mx->nd->array([1]);
+    my $loss  = gluon->loss->SoftmaxCrossEntropyLoss();
+
+    # Outer container holding Dense(4) and a nested Dense(3) -> Dense(2) chain.
+    my $net  = nn->HybridSequential();
+    my $net1 = nn->HybridSequential();
+    $net1->add(nn->Dense(4));
+    my $net2 = nn->HybridSequential();
+    $net2->add(nn->Dense(3), nn->Dense(2));
+    $net->add($net1, $net2);
+    $net->initialize();
+    $net->hybridize();
+
+    # Accumulate gradients across backward passes instead of overwriting them.
+    for my $param ($net->collect_params->values)
+    {
+        $param->grad_req('add');
+    }
+    $net->collect_params->zero_grad();
+
+    my $first_layer_grad = sub { $net->[0][0]->weight->grad->mean->aspdl };
+    my $forward_backward = sub {
+        my $pred = $net->($data);
+        my $l = $loss->($pred, $label);
+        $l->backward();
+    };
+    my $grad;
+    mx->autograd->record(sub {
+        $forward_backward->();
+        $grad = $first_layer_grad->();
+        # With req 'add', a second pass must double the stored gradient.
+        $forward_backward->();
+    });
+
+    ok(almost_equal($grad * 2, $first_layer_grad->()));
+}
+
+test_req();
+
+sub test_zero_grad
+{
+    my $indices = mx->nd->random->uniform(shape=>[3,3]);
+    my $embed = nn->Embedding(3, 4, sparse_grad=>1, prefix=>'test_zero_grad_');
+    $embed->initialize();
+    mx->autograd->record(sub { $embed->($indices)->backward; });
+
+    # zero_grad must reset the (sparse) weight gradient to all zeros.
+    my $params = $embed->collect_params;
+    $params->zero_grad;
+    my $grad = $params->{test_zero_grad_weight}->grad;
+    ok(almost_equal($grad->aspdl, $grad->aspdl * 0));
+}
+
+test_zero_grad();
+
+sub test_hook
+{
+    # Call counters for the forward hook and the forward pre-hook.
+    my %calls = (hook => 0, pre_hook => 0);
+
+    my $block = nn->Dense(10);
+    $block->initialize();
+    my $handle     = $block->register_forward_hook(sub { ++$calls{hook} });
+    my $pre_handle = $block->register_forward_pre_hook(sub { ++$calls{pre_hook} });
+    my $run = sub { $block->(mx->nd->ones([3, 5])) };
+
+    $run->();
+    ok($calls{hook} == 1);
+    ok($calls{pre_hook} == 1);
+
+    # After detaching the forward hook only the pre-hook keeps firing.
+    $handle->detach();
+    $run->();
+    ok($calls{hook} == 1);
+    ok($calls{pre_hook} == 2);
+
+    # With both detached neither counter moves.
+    $pre_handle->detach();
+    $run->();
+    ok($calls{hook} == 1);
+    ok($calls{pre_hook} == 2);
+}
+
+test_hook();
+
+sub test_apply
+{
+    my @called_blocks;
+    my $block = nn->HybridSequential(prefix=>'test_');
+    $block->name_scope(sub {
+        $block->add(nn->Dense(10));
+        $block->add(nn->Dropout(0.5));
+    });
+    # Record the name of every block apply() visits; the children are
+    # expected before the container itself.
+    $block->apply(sub { push @called_blocks, $_[0]->name });
+
+    is_deeply(\@called_blocks, ['test_dense0', 'test_dropout0', 'test']);
+}
+
+test_apply();
+
+sub test_sparse_hybrid_block_grad
+{
+    # HybridBlock wrapping a sparse-gradient Embedding and adding one to its output.
+    package Embedding {
+        use AI::MXNet::Gluon::Mouse;
+        use AI::MXNet::Function::Parameters;
+        extends 'AI::MXNet::Gluon::HybridBlock';
+        has ['num_tokens', 'embedding_size'] => (is => 'rw');
+        method python_constructor_arguments() { ['num_tokens', 'embedding_size'] }
+        sub BUILD {
+            my ($self) = @_;
+            $self->name_scope(sub {
+                my $inner = nn->Embedding($self->num_tokens, $self->embedding_size, sparse_grad=>1);
+                $self->embedding($inner);
+            });
+        }
+
+        method hybrid_forward($F, $words)
+        {
+            my $looked_up = $self->embedding->($words);
+            return $looked_up + $F->ones_like($looked_up);
+        }
+    };
+    my $model = Embedding->new(20, 3);
+    $model->initialize();
+    $model->hybridize();
+
+    # Look up token ids 0..9 twice and backpropagate the combined sum.
+    my $loss;
+    mx->autograd->record(sub {
+        my $first  = $model->(mx->nd->arange(stop => 10))->sum;
+        my $second = $model->(mx->nd->arange(stop => 10))->sum;
+        $loss = $first + $second;
+    });
+    $loss->backward();
+    # Rows 0..9 were hit twice (gradient 2), the remaining rows never (gradient 0).
+    my $grad = $model->embedding->weight->grad->aspdl;
+    ok(($grad->slice('X', ':9') == 2)->all);
+    ok(($grad->slice('X', '10:') == 0)->all);
+}
+
+test_sparse_hybrid_block_grad();
+
+sub test_sparse_hybrid_block
+{
+    # Linear: dot(x, w) with a units x units weight parameter named 'w'.
+    package Linear {
+        use AI::MXNet::Gluon::Mouse;
+        use AI::MXNet::Function::Parameters;
+        extends 'AI::MXNet::Gluon::HybridBlock';
+        has ['units'] => (is => 'rw');
+        method python_constructor_arguments() { ['units'] }
+        sub BUILD {
+            my ($self) = @_;
+            $self->name_scope(sub {
+                my $weight = $self->params->get('w', shape => [$self->units, $self->units]);
+                $self->w($weight);
+            });
+        }
+        method hybrid_forward($F, $x, :$w)
+        {
+            return $F->dot($x, $w);
+        }
+    };
+    # SparseBlock: the Linear output multiplied by the input.
+    package SparseBlock {
+        use AI::MXNet::Gluon::Mouse;
+        use AI::MXNet::Function::Parameters;
+        extends 'AI::MXNet::Gluon::HybridBlock';
+        has ['units'] => (is => 'rw');
+        method python_constructor_arguments() { ['units'] }
+        sub BUILD {
+            my ($self) = @_;
+            $self->name_scope(sub {
+                my $linear = Linear->new($self->units);
+                $self->net($linear);
+            });
+        }
+        method hybrid_forward($F, $x)
+        {
+            my $projected = $self->net->($x);
+            return $projected * $x;
+        }
+    };
+    my $model = SparseBlock->new(2);
+    $model->initialize();
+    $model->hybridize();
+    # Feed the same all-ones CSR input twice and backpropagate the sum.
+    my $csr_input = mx->nd->ones([2,2])->tostype('csr');
+    my $result;
+    mx->autograd->record(sub {
+        $result = $model->($csr_input) + $model->($csr_input);
+    });
+    $result->backward;
+    # The test expects every weight-gradient entry to be 4.
+    ok(($model->net->w->grad->aspdl == 4)->all);
+}
+
+test_sparse_hybrid_block();
diff --git a/perl-package/AI-MXNet/t/test_gluon_trainer.t b/perl-package/AI-MXNet/t/test_gluon_trainer.t
new file mode 100644
index 0000000..8b3b52b
--- /dev/null
+++ b/perl-package/AI-MXNet/t/test_gluon_trainer.t
@@ -0,0 +1,253 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+use strict;
+use warnings;
+use Test::More tests => 30;
+use AI::MXNet qw(mx);
+use AI::MXNet::Gluon qw(gluon);
+use AI::MXNet::Gluon::NN qw(nn);
+use AI::MXNet::TestUtils qw(almost_equal dies_ok);
+use Scalar::Util qw(refaddr);
+use AI::MXNet::Base;
+
+sub test_multi_trainer
+{    # NOTE(review): this sub is never called in this file, so its assertions never run - confirm that is intended
+    my $x = gluon->Parameter('x', shape=>[10], stype=>'row_sparse');    # sparse (row_sparse) parameter
+    $x->initialize();
+    # test set trainer
+    my $trainer0 = gluon->Trainer([$x], 'sgd');    # constructing a Trainer attaches it to the parameter
+    ok(refaddr($x->_trainer) == refaddr($trainer0));
+    # test unset trainer
+    $x->_set_trainer(undef);
+    ok(not defined $x->_trainer);
+    $x->_set_trainer($trainer0);
+    # multiple trainers for a sparse Parameter are not allowed
+    dies_ok(sub { gluon->Trainer([$x], 'sgd') });
+}
+
+sub test_trainer
+{
+    my $dict_equ = sub { my ($got, $want) = @_;    # not $a/$b: lexicals with those names shadow sort's variables
+        is_deeply({ map { $_ => 1 } keys %$got }, { map { $_ => 1 } keys %$want }, 'optimizer state keys match');
+        for my $k (keys %$got)
+        {
+            ok(($got->{$k}->aspdl == $want->{$k}->aspdl)->all, "optimizer state '$k' matches");
+        }
+    };
+    my $x = gluon->Parameter('x', shape=>[10]);
+    $x->initialize(ctx=>[mx->cpu(0), mx->cpu(1)], init=>'zeros');
+    my $trainer = gluon->Trainer([$x], 'sgd', {'learning_rate'=> 1.0, 'momentum'=> 0.5});
+    my $y;
+    mx->autograd->record(sub {
+        for my $w (@{ $x->list_data() })
+        {
+            $y = $w + 1;
+            $y->backward();
+        }
+    });
+    $trainer->step(1);
+    # y = w + 1 gives a gradient of one on each of the two replicas
+    ok(($x->data(mx->cpu(1))->aspdl == -2)->all, 'first step');
+
+    $x->lr_mult(0.5);
+    # repeat the same pass with the parameter's learning-rate multiplier halved
+    mx->autograd->record(sub {
+        for my $w (@{ $x->list_data() })
+        {
+            $y = $w + 1;
+            $y->backward();
+        }
+    });
+    $trainer->step(1);
+
+    ok(($x->data(mx->cpu(1))->aspdl == -4)->all, 'second step with lr_mult 0.5');
+    # save the optimizer states, snapshot them (shallow hash copy), reload and compare
+    $trainer->save_states('test_trainer.states');
+    my $states;
+    if($trainer->update_on_kvstore)
+    {
+        $states = { %{ $trainer->kvstore->_updater->states } };
+    }
+    else
+    {
+        $states = { %{ $trainer->_updaters->[0]->states } };
+    }
+    $trainer->load_states('test_trainer.states');
+    if($trainer->update_on_kvstore)
+    {
+        $dict_equ->($trainer->kvstore->_updater->states, $states);
+        ok(refaddr($trainer->_optimizer) == refaddr($trainer->kvstore->_updater->optimizer), 'kvstore updater shares the trainer optimizer');
+    }
+    else
+    {
+        for my $updater (@{ $trainer->_updaters })
+        {
+            $dict_equ->($updater->states, $states);
+        }
+        ok(refaddr($trainer->_optimizer) == refaddr($trainer->_updaters->[0]->optimizer), 'updaters share the trainer optimizer');
+    }
+    # manual update()/allreduce_grads() are expected to die for this trainer
+    dies_ok(sub { $trainer->update(1) });
+    dies_ok(sub { $trainer->allreduce_grads() });
+    # with update_on_kvstore=>0 the gradients are reduced and applied manually
+    $x = gluon->Parameter('x', shape=>[10]);
+    $x->initialize(ctx=>[mx->cpu(0), mx->cpu(1)], init=>'zeros');
+    my $trainer2 = gluon->Trainer([$x], 'sgd', {learning_rate => 1.0, momentum => 0.5},
+                             update_on_kvstore=>0);
+    mx->autograd->record(sub {
+        for(enumerate($x->list_data))
+        {
+            my ($i, $w) = @$_;
+            my $y = $i*$w;
+            $y->backward;
+        }
+    });
+    ok(($x->grad(mx->cpu(0))->aspdl != $x->grad(mx->cpu(1))->aspdl)->all, 'replica gradients differ before allreduce');
+    $trainer2->allreduce_grads;
+    ok(($x->grad(mx->cpu(0))->aspdl == $x->grad(mx->cpu(1))->aspdl)->all, 'replica gradients equal after allreduce');
+    $trainer2->update(1);
+    ok(($x->data(mx->cpu(1))->aspdl == -1)->all, 'update applies the reduced gradient');
+
+}
+
+test_trainer();
+
+sub test_trainer_save_load
+{
+    # A single dense parameter replicated on two CPU contexts.
+    my $param = gluon->Parameter('x', shape=>[10], lr_mult=>1.0);
+    $param->initialize(ctx=>[mx->cpu(0), mx->cpu(1)], init=>'zeros');
+    my $trainer = gluon->Trainer([$param], 'sgd', {learning_rate => 0.1});
+    mx->autograd->record(sub {
+        ($_ + 1)->backward() for @{ $param->list_data };
+    });
+    $trainer->step(1);
+    my $optimizer_lr = sub { $trainer->kvstore->_updater->optimizer->_get_lr(0) };
+    ok($optimizer_lr->() == 0.1);
+    my $states_file = 'test_trainer_save_load.states';
+    $trainer->save_states($states_file);
+    $trainer->load_states($states_file);
+    # An lr_mult change made after load_states must still reach the optimizer,
+    # i.e. the parameter dict stays associated with the reloaded optimizer.
+    $param->lr_mult(2.0);
+    ok($optimizer_lr->() == 0.2);
+}
+
+test_trainer_save_load();
+
+sub test_trainer_multi_layer_init
+{    # a row_sparse embedding weight and a dense layer trained together on one or more contexts
+    local($ENV{MXNET_STORAGE_FALLBACK_LOG_VERBOSE}) = 0;    # presumably silences storage-fallback logging; restored on scope exit
+    package Net {    # plain Block: sparse embedding lookup followed by Dense(2)
+        use AI::MXNet::Gluon::Mouse;
+        extends 'AI::MXNet::Gluon::Block';
+        use AI::MXNet::Function::Parameters;
+        sub BUILD {
+            my $self = shift;
+            $self->name_scope(sub {
+                # sparse param
+                $self->embed_weight($self->params->get('embed_weight', stype=>'row_sparse',
+                                                    shape=>[4,3], grad_stype=>'row_sparse'));
+                # dense param from a hybrid block
+                $self->dense0(nn->Dense(2));
+            });
+        }
+        method forward($x)
+        {
+            my $embed_weight = $self->embed_weight->row_sparse_data($x);    # rows of the sparse weight for the ids in $x
+            my $embed = mx->nd->Embedding(data=>$x, weight=>$embed_weight,
+                                    input_dim=>4, output_dim=>3, sparse_grad=>1);
+            return $self->dense0->($embed);
+        }
+    };
+    my $check_init = sub { my ($ctxes) = @_;    # one training step on the given contexts, then check the weights
+        my $net = Net->new(prefix=>'net_');
+        $net->initialize(mx->init->One(), ctx=>$ctxes);
+        my $trainer = gluon->Trainer($net->collect_params(), 'sgd', {learning_rate => 1});
+        my $data = mx->nd->array([[0,2], [1,2]]);    # row ids: 2 appears in both samples, 3 in neither
+        my $xs = gluon->utils->split_and_load($data, ctx_list => $ctxes);    # one slice of $data per context
+        my @ys;
+        mx->autograd->record(sub {
+            for my $x (@{ $xs })
+            {
+                my $y = $net->($x);
+                push @ys, $y;
+            }
+        });
+        for my $y (@ys)
+        {
+            $y->backward;
+        }
+        $trainer->step(1);
+        # all parameters should be initialized
+        ok(not @{ $trainer->_params_to_init });
+        my $all_rows = mx->nd->arange(start => 0, stop => 4, ctx=>mx->cpu(1));    # ids of every embedding row
+        # check the updated weights
+        my $weight = $net->embed_weight->row_sparse_data($all_rows)->aspdl;
+        ok(($weight->at(0) == -1)->all);
+        ok(($weight->at(1) == -1)->all);
+        ok(($weight->at(2) == -3)->all);    # row 2 is looked up by both samples
+        ok(($weight->at(3) ==  1)->all);    # row 3 is never looked up, so it keeps its initial value 1
+    };
+    $check_init->([mx->cpu(1), mx->cpu(2)]);
+    $check_init->([mx->cpu(1)]);    # and again with a single context
+}
+
+test_trainer_multi_layer_init();
+
+sub test_trainer_reset_kv
+{
+    # For each kvstore type, loading saved params must reset the trainer's kvstore.
+    my $run_case = sub {
+        my ($kvstore_type) = @_;
+        my $param_dict = gluon->ParameterDict();
+        my $param = $param_dict->get('x', shape=>[10], lr_mult=>1.0);
+        $param_dict->initialize(ctx=>[mx->cpu(0), mx->cpu(1)], init=>'zeros');
+        my $trainer = gluon->Trainer(
+            $param_dict, 'sgd', {learning_rate => 0.1}, kvstore=>$kvstore_type
+        );
+        $param_dict->save('test_trainer_reset_kv.params');
+        # forward y = w + 1 on each replica and backpropagate it
+        mx->autograd->record(sub {
+            ($_ + 1)->backward for @{ $param->list_data };
+        });
+        $trainer->step(1);
+        is($trainer->kvstore->type, $kvstore_type);
+        # load would reset kvstore
+        $param_dict->load(
+            'test_trainer_reset_kv.params', ctx => [mx->cpu(0), mx->cpu(1)]
+        );
+        ok(!defined $trainer->kvstore);
+        ok(defined($trainer->_kv_initialized) && !$trainer->_kv_initialized);
+        # repeat the same pass after the reload
+        mx->autograd->record(sub {
+            ($_ + 1)->backward for @{ $param->list_data };
+        });
+        $trainer->step(1);
+        # the updated parameter should be based on the loaded checkpoint
+        my $updated = $param->data(mx->cpu());
+        ok(($updated == -0.2)->aspdl->all);
+    };
+    for my $kvstore_type (qw(local device))
+    {
+        $run_case->($kvstore_type);
+    }
+}
+
+test_trainer_reset_kv();
+
diff --git a/perl-package/AI-MXNet/t/test_loss.t b/perl-package/AI-MXNet/t/test_loss.t
index 58f7b27..3a136f4 100644
--- a/perl-package/AI-MXNet/t/test_loss.t
+++ b/perl-package/AI-MXNet/t/test_loss.t
@@ -25,6 +25,7 @@ use Hash::Ordered;
 
 sub test_loss_ndarray
 {
+    mx->random->seed(1234);
     my $output     = mx->nd->array([1, 2, 3, 4]);
     my $label      = mx->nd->array([1, 3, 5, 7]);
     my $weighting  = mx->nd->array([0.5, 1, 0.5, 1]);
@@ -72,6 +73,7 @@ sub get_net
 
 sub test_ce_loss
 {
+    mx->random->seed(1234);
     my $nclass = 10;
     my $N = 20;
     my $data = mx->random->uniform(-1, 1, shape=>[$N, $nclass]);
@@ -93,6 +95,7 @@ test_ce_loss();
 
 sub test_bce_loss
 {
+    mx->random->seed(1234);
     my $N = 20;
     my $data = mx->random->uniform(-1, 1, shape=>[$N, 20]);
     my $label = mx->nd->array([qw/1 1 0 1 0 0 0 1 1 1 1 1 0 0 1 0 0 0 0 0/], dtype=>'float32');
@@ -114,6 +117,7 @@ test_bce_loss();
 
 sub test_bce_equal_ce2
 {
+    mx->random->seed(1234);
     my $N = 100;
     my $loss1 = gluon->loss->SigmoidBCELoss(from_sigmoid=>1);
     my $loss2 = gluon->loss->SoftmaxCELoss(from_logits=>1);
@@ -127,6 +131,7 @@ test_bce_equal_ce2();
 
 sub test_kl_loss
 {
+    mx->random->seed(1234);
     my $N = 20;
     my $data = mx->random->uniform(-1, 1, shape=>[$N, 10]);
     my $label = mx->nd->softmax(mx->random->uniform(0, 1, shape=>[$N, 2]));
@@ -147,6 +152,7 @@ test_kl_loss();
 
 sub test_l2_loss
 {
+    mx->random->seed(1234);
     my $N = 20;
     my $data = mx->random->uniform(-1, 1, shape=>[$N, 10]);
     my $label = mx->nd->softmax(mx->random->uniform(-1, 1, shape=>[$N, 1]));
@@ -167,6 +173,7 @@ test_l2_loss();
 
 sub test_l1_loss
 {
+    mx->random->seed(1234);
     my $N = 20;
     my $data = mx->random->uniform(-1, 1, shape=>[$N, 10]);
     my $label = mx->nd->softmax(mx->random->uniform(-1, 1, shape=>[$N, 1]));
@@ -187,6 +194,7 @@ test_l1_loss();
 
 sub test_ctc_loss
 {
+    mx->random->seed(1234);
     my $loss = gluon->loss->CTCLoss();
     my $l = $loss->(mx->nd->ones([2,20,4]), mx->nd->array([[1,0,-1,-1],[2,1,1,-1]]));
     ok(almost_equal($l->aspdl, mx->nd->array([18.82820702, 16.50581741])->aspdl));
@@ -216,6 +224,7 @@ test_ctc_loss();
 
 sub test_ctc_loss_train
 {
+    mx->random->seed(1234);
     my $N = 20;
     my $data = mx->random->uniform(-1, 1, shape=>[$N, 20, 10]);
     my $label = mx->nd->arange(start => 4, repeat=>$N)->reshape([$N, 4]);
@@ -227,7 +236,7 @@ sub test_ctc_loss_train
     $loss = mx->sym->make_loss($loss);
     local($AI::MXNet::Logging::silent) = 1;
     my $mod = mx->mod->Module($loss, data_names=>['data'], label_names=>['label']);
-    $mod->fit($data_iter, num_epoch=>200, optimizer_params=>{learning_rate => 1},
+    $mod->fit($data_iter, num_epoch=>200, optimizer_params=>{learning_rate => 0.01},
             initializer=>mx->init->Xavier(magnitude=>2), eval_metric=>mx->metric->Loss(),
             optimizer=>'adam');
     ok($mod->score($data_iter, mx->metric->Loss())->{loss} < 20);
@@ -237,6 +246,7 @@ test_ctc_loss_train();
 
 sub test_sample_weight_loss
 {
+    mx->random->seed(1234);
     my $nclass = 10;
     my $N = 20;
     my $data = mx->random->uniform(-1, 1, shape=>[$N, $nclass]);
diff --git a/perl-package/AI-MXNet/t/test_ndarray.t b/perl-package/AI-MXNet/t/test_ndarray.t
index aa8d120..1190d52 100644
--- a/perl-package/AI-MXNet/t/test_ndarray.t
+++ b/perl-package/AI-MXNet/t/test_ndarray.t
@@ -18,27 +18,30 @@
 use strict;
 use warnings;
 use AI::MXNet qw(mx);
-use AI::MXNet::TestUtils qw(almost_equal same);
-use Test::More tests => 20;
+use AI::MXNet::TestUtils qw(almost_equal same rand_ndarray randint zip);
+use Test::More tests => 251;
 use PDL;
+use File::Temp qw(tempdir);
+use IO::File;
 
 sub test_ndarray_reshape
 {
-    my $tensor = mx->nd->array([[[1, 2], [3, 4]],
-                                [[5, 6], [7, 8]]]);
-    my $true_res = mx->nd->arange(stop => 8) + 1;
-    is_deeply($tensor->reshape([-1])->aspdl->unpdl, $true_res->aspdl->unpdl);
-    $true_res  = mx->nd->array([[1, 2, 3, 4],
-                                [5, 6, 7, 8]]);
-    is_deeply($tensor->reshape([2, -1])->aspdl->unpdl, $true_res->aspdl->unpdl);
-    $true_res  = mx->nd->array([[1, 2],
-                                [3, 4],
-                                [5, 6],
-                                [7, 8]]);
-    is_deeply($tensor->reshape([-1, 2])->aspdl->unpdl, $true_res->aspdl->unpdl);
+    my $tensor = (mx->nd->arange(stop => 30) + 1)->reshape([2, 3, 5]);    # values 1..30, shape (2, 3, 5)
+    my $true_res = mx->nd->arange(stop => 30) + 1;    # the same values, flat
+    ok(same($tensor->reshape([-1])->aspdl, $true_res->aspdl));    # -1 alone infers the total size
+    ok(same($tensor->reshape([2, -1])->aspdl, $true_res->reshape([2, 15])->aspdl));
+    ok(same($tensor->reshape([0, -1])->aspdl, $true_res->reshape([2, 15])->aspdl));    # 0 copies the input dim (2)
+    ok(same($tensor->reshape([-1, 2])->aspdl, $true_res->reshape([15, 2])->aspdl));
+    ok(same($tensor->reshape([6, 5])->aspdl, $true_res->reshape([6, 5])->aspdl));
+    ok(same($tensor->reshape([30])->aspdl, $true_res->aspdl));
+    ok(same($tensor->reshape([-1, 6])->aspdl, $true_res->reshape([5, 6])->aspdl));
+    ok(same($tensor->reshape([-2])->aspdl, $true_res->reshape([2, 3, 5])->aspdl));    # -2 copies all remaining input dims
+    ok(same($tensor->reshape([-3, -1])->aspdl, $true_res->reshape([6, 5])->aspdl));    # -3 merges two input dims (2*3)
+    ok(same($tensor->reshape([-1, 15])->reshape([0, -4, 3, -1])->aspdl, $true_res->reshape([2, 3, 5])->aspdl));    # -4 splits 15 into 3 x 5
+    ok(same($tensor->reshape([-1, 0])->aspdl, $true_res->reshape([10, 3])->aspdl));    # 0 copies input dim 1 (3), -1 infers 10
+    ok(same($tensor->reshape([-1, 0], reverse=>1)->aspdl, $true_res->reshape([6, 5])->aspdl));    # reverse matches from the right: 0 copies 5
 }
 
-
 sub test_moveaxis
 {
     my $X = mx->nd->array([[[1, 2, 3], [4, 5, 6]],
@@ -169,6 +172,51 @@ sub test_image_to_tensor
     );
 }
 
+sub test_buffer_load
+{
+    my $nrepeat = 10;
+    my $tmpdir = tempdir(CLEANUP => 1);
+    # Slurp a saved file as raw bytes: '<:raw' avoids CRLF/encoding translation of binary data.
+    my $slurp = sub {
+        my $fh = IO::File->new($_[0], '<:raw') or die "Cannot open $_[0]: $!";
+        local $/;
+        return scalar <$fh>;
+    };
+    for my $repeat (1..$nrepeat)
+    {
+        # test load_buffer as list
+        my @data = map { rand_ndarray([randint(1, 5)], 'default') } 1..10;
+        my $fname = "$tmpdir/list_$repeat.param";
+        mx->nd->save($fname, \@data);
+        my $data2 = mx->nd->load_frombuffer($slurp->($fname));
+        is(scalar @$data2, scalar @data);
+        zip(sub {
+            my ($x, $y) = @_;
+            ok(same($x->aspdl, $y->aspdl));
+        }, \@data, $data2);
+        # test load_buffer as hash
+        my $i = 0;
+        my %hash = map { 'ndarray xx '.$i++ => $_ } @data;
+        $fname = "$tmpdir/hash_$repeat.param";
+        mx->nd->save($fname, \%hash);
+        my $hash2 = mx->nd->load_frombuffer($slurp->($fname));
+        is(scalar keys %$hash2, scalar keys %hash);    # a bare %hash in numeric context is not a key count before Perl 5.26
+        for my $k (sort keys %hash)
+        {
+            ok(exists $hash2->{$k} && same($hash{$k}->aspdl, $hash2->{$k}->aspdl));
+        }
+    }
+}
+
+sub test_histogram
+{
+    my $z = mx->nd->array([0..99]);
+    my $bin_edges = mx->nd->array([10, 20, 30, 60]);    # not $b: a lexical $b shadows sort's $b
+    my ($hist, $bins) = @{ mx->nd->histogram($z, bins => $bin_edges) };
+    ok(same($hist->aspdl, pdl([10, 10, 31])), 'histogram counts (last bin includes its right edge 60)');
+    ok(same($bins->aspdl, pdl([10, 20, 30, 60])), 'histogram returns the bin edges');
+}
+
 test_ndarray_slice();
 test_ndarray_reshape();
 test_moveaxis();
@@ -176,4 +224,5 @@ test_output();
 test_cached();
 test_linalg_gemm2();
 test_image_to_tensor();
-
+test_buffer_load();
+test_histogram();
diff --git a/perl-package/AI-MXNet/t/test_optimizers.t b/perl-package/AI-MXNet/t/test_optimizers.t
index 07f362a..af3e54e 100644
--- a/perl-package/AI-MXNet/t/test_optimizers.t
+++ b/perl-package/AI-MXNet/t/test_optimizers.t
@@ -638,7 +638,7 @@ method update($index, $weight, $grad, $state)
 
 package main;
 use Carp;
-use Test::More tests => 7884;
+use Test::More tests => 7992;
 use AI::MXNet::Base;
 use PDL::NiceSlice;
 use AI::MXNet::TestUtils qw(same reldiff almost_equal rand_ndarray);
@@ -1075,6 +1075,7 @@ sub test_adagrad
                         if(($wd_option->{wd}//0) == 0)
                         {
                             compare_optimizer($opt1->new(%kwarg), $opt2->new(%kwarg), $shape, $dtype, 'row_sparse', 'row_sparse');
+                            compare_optimizer($opt1->new(%kwarg), $opt2->new(%kwarg), $shape, $dtype, 'default', 'row_sparse');
                         }
                     }
                 }
diff --git a/perl-package/AI-MXNet/t/test_random.t b/perl-package/AI-MXNet/t/test_random.t
index 6f275b5..542f79c 100644
--- a/perl-package/AI-MXNet/t/test_random.t
+++ b/perl-package/AI-MXNet/t/test_random.t
@@ -17,7 +17,7 @@
 
 use strict;
 use warnings;
-use Test::More tests => 505;
+use Test::More tests => 506;
 use AI::MXNet qw(mx);
 use AI::MXNet::TestUtils qw(same enumerate);
 
@@ -225,3 +225,13 @@ sub test_sample_multinomial
 
 test_sample_multinomial();
 
+sub test_seed_context
+{
+    ## only checking perl/swig interaction
+    ## c++ implementation is tested on python's side thoroughly already
+    my $seeded = eval { mx->random->seed(1234); 1 };
+    $seeded &&= eval { mx->random->seed(1234, ctx => mx->cpu(0)); 1 };
+    ok($seeded, 'global and per-context seeding succeed') or diag($@);
+}
+
+test_seed_context();
diff --git a/perl-package/AI-MXNet/t/test_symbol.t b/perl-package/AI-MXNet/t/test_symbol.t
index e102246..09bab2f 100644
--- a/perl-package/AI-MXNet/t/test_symbol.t
+++ b/perl-package/AI-MXNet/t/test_symbol.t
@@ -17,7 +17,7 @@
 
 use strict;
 use warnings;
-use Test::More tests => 101;
+use Test::More tests => 103;
 use AI::MXNet qw(mx);
 use AI::MXNet::TestUtils qw(mlp2 conv check_consistency zip assert enumerate almost_equal same);
 use Storable qw(freeze thaw);
@@ -279,6 +279,17 @@ sub test_image_to_tensor
 
 test_image_to_tensor();
 
+sub test_histogram
+{
+    my $z = mx->nd->array([0..99]);
+    my $bin_edges = mx->nd->array([10, 20, 30, 60]);    # not $b: a lexical $b shadows sort's $b
+    my ($hist, $bins) = @{ mx->sym->histogram(mx->sym->var("z"), bins => mx->sym->var("bins"))->eval(args => { z => $z, bins => $bin_edges }) };
+    ok(same($hist->aspdl, pdl([10, 10, 31])), 'symbolic histogram counts (last bin includes its right edge 60)');
+    ok(same($bins->aspdl, pdl([10, 20, 30, 60])), 'symbolic histogram returns the bin edges');
+}
+
+test_histogram();
+
 __DATA__
 {
   "nodes": [
diff --git a/perl-package/AI-MXNetCAPI/Changes b/perl-package/AI-MXNetCAPI/Changes
index 30426a5..8dad8b4 100644
--- a/perl-package/AI-MXNetCAPI/Changes
+++ b/perl-package/AI-MXNetCAPI/Changes
@@ -1,5 +1,8 @@
 Revision history for Perl extension AI::MXNetCAPI
 
+1.3     Tue Jun 26 20:57:40 PDT 2018
+        - Major update, Gluon interface updated to parity with Python's API
+
 1.2     Sun Mar  4 16:29:19 PST 2018
         - Support for sparse tensors
 
diff --git a/perl-package/AI-MXNetCAPI/META.json b/perl-package/AI-MXNetCAPI/META.json
index e194db9..35271e3 100644
--- a/perl-package/AI-MXNetCAPI/META.json
+++ b/perl-package/AI-MXNetCAPI/META.json
@@ -37,5 +37,5 @@
       }
    },
    "release_status" : "stable",
-   "version" : "1.2"
+   "version" : "1.3"
 }
diff --git a/perl-package/AI-MXNetCAPI/META.yml b/perl-package/AI-MXNetCAPI/META.yml
index fa0bd13..48760da 100644
--- a/perl-package/AI-MXNetCAPI/META.yml
+++ b/perl-package/AI-MXNetCAPI/META.yml
@@ -19,4 +19,4 @@ no_index:
     - inc
 requires:
   Test::More: '0'
-version: '1.2'
+version: '1.3'
diff --git a/perl-package/AI-MXNetCAPI/README b/perl-package/AI-MXNetCAPI/README
index dbed0c2..dca8b4a 100644
--- a/perl-package/AI-MXNetCAPI/README
+++ b/perl-package/AI-MXNetCAPI/README
@@ -1,4 +1,4 @@
-AI-MXNetCAPI version 1.2
+AI-MXNetCAPI version 1.3
 =====================
 
 Swig interface to MXNet c api.
diff --git a/perl-package/AI-MXNetCAPI/lib/AI/MXNetCAPI.pm b/perl-package/AI-MXNetCAPI/lib/AI/MXNetCAPI.pm
index ef51539..b578507 100644
--- a/perl-package/AI-MXNetCAPI/lib/AI/MXNetCAPI.pm
+++ b/perl-package/AI-MXNetCAPI/lib/AI/MXNetCAPI.pm
@@ -18,7 +18,7 @@
 package AI::MXNetCAPI;
 use base qw(DynaLoader);
 bootstrap AI::MXNetCAPI;
-our $VERSION = '1.2';
+our $VERSION = '1.3';    # bumped together with META.json, META.yml, README and Changes
 1;
 __END__
 
diff --git a/perl-package/AI-MXNetCAPI/mxnet.i b/perl-package/AI-MXNetCAPI/mxnet.i
index 517a6da..2540e1b 100644
--- a/perl-package/AI-MXNetCAPI/mxnet.i
+++ b/perl-package/AI-MXNetCAPI/mxnet.i
@@ -284,6 +284,12 @@ const char *MXGetLastError();
  */
 int MXRandomSeed(int seed);
 /*!
+ * \brief Seed the global random number generator of the device given by dev_type/dev_id.
+ * \param seed the random number seed.
+ * \return 0 when success, -1 when failure happens.
+ */
+int MXRandomSeedContext(int seed, int dev_type, int dev_id);
+/*!
  * \brief Notify the engine about a shutdown,
  *  This can help engine to print less messages into display.
  *
@@ -322,6 +328,21 @@ int MXSetNumOMPThreads(int thread_num);
  */
 int MXGetVersion(int *out);
 
+/*!
+ * \brief set bulk execution limit
+ * \param bulk_size new bulk_size
+ * \param out receives the previous bulk_size
+ */
+int MXEngineSetBulkSize(int bulk_size, int* out);
+
+/*!
+ * \brief Get the number of GPUs.
+ * \param out pointer to int that will hold the number of GPUs available.
+ * \return 0 when success, -1 when failure happens.
+ */
+int MXGetGPUCount(int* out);
+
+
 //-------------------------------------
 // Part 1: NDArray creation and deletion
 //-------------------------------------
@@ -447,6 +468,28 @@ int MXNDArrayLoad(const char* fname,
                             NDArrayHandle** out_array,
                             mx_uint *out_size,
                             const char*** out_array);
+
+/*!
+ * \brief Load list / dictionary of ndarrays from file content loaded into memory.
+ * Works like MXNDArrayLoad, but reads from a buffer holding
+ * the contents of a file rather than from a named file.
+ * The out_size/out_array names are repeated on purpose: each
+ * pair is matched by its own SWIG typemap.
+ * \param in pointer to the start of the ndarray file content
+ * \param size size of the buffer pointed to by in
+ * \param out_size (first) number of ndarrays loaded.
+ * \param out_array (first) head of the returning ndarray handles.
+ * \param out_size (second) size of the output name array.
+ * \param out_array (second) the names of returning NDArrays, can be NULL
+ * \return 0 when success, -1 when failure happens
+ */
+int MXNDArrayLoadFromBuffer(const void *in,
+                            size_t size,
+                            mx_uint *out_size,
+                            NDArrayHandle** out_array,
+                            mx_uint *out_size,
+                            const char*** out_array);
+
 /*!
  * \brief Perform a synchronize copy from a continugous CPU memory region.
  *
@@ -558,6 +601,20 @@ int MXNDArrayReshape(NDArrayHandle handle,
                                int *in,
                                NDArrayHandle *out);
 /*!
+ * \brief Reshape the NDArray with dim_t dims; returns 0 on success, -1 on failure.
+ * \param handle the handle to the ndarray
+ * \param ndim number of dimensions of new shape
+ * \param in new shape; may hold special codes (e.g. 0 copies an input dim, -1 infers one)
+ * \param reverse if true, special codes are matched against the input shape from the right
+ * \param out the NDArrayHandle of reshaped NDArray
+ */
+int MXNDArrayReshape64(NDArrayHandle handle,
+                                 int ndim,
+                                 dim_t *in,
+                                 bool reverse,
+                                 NDArrayHandle *out);
+
+/*!
  * \brief get the shape of the array
  * \param handle the handle to the ndarray
  * \param out_dim the output dimension
@@ -874,6 +931,14 @@ int MXAutogradGetSymbol(NDArrayHandle handle, SymbolHandle *out);
   */
 int MXCreateCachedOp(SymbolHandle handle,
                                 CachedOpHandle *out);
+/*!
+ * \brief create cached operator, configured by num_flags key/value string pairs
+ */
+int MXCreateCachedOpEx(SymbolHandle handle,
+                                 int num_flags,
+                                 const char** keys,
+                                 const char** vals,
+                                 CachedOpHandle *out);
  /*!
   * \brief free cached operator
   */
@@ -1488,6 +1553,47 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
 );
 
 /*!
+ * \brief Return a new executor with the same symbol and shared memory,
+ * but different input/output shapes.
+ * \param partial_shaping Whether to allow changing the shape of unspecified arguments.
+ * \param allow_up_sizing Whether to allow allocating new ndarrays that's larger than the original.
+ * \param dev_type, dev_id device type / device id of the default context
+ * \param num_map_keys size of group2ctx map
+ * \param map_keys, map_dev_types, map_dev_ids keys, device types and device ids of group2ctx map
+ * \param num_provided_arg_shapes number of argument shapes provided
+ * \param provided_arg_shape_names names of the arguments whose shapes are provided
+ * \param provided_arg_shape_data, provided_arg_shape_idx shape dims and per-shape offsets into them
+ * \param couple_out_size number of in_args / arg_grads handles returned
+ * \param out_first_array returned in_args handle array
+ * \param out_second_array returned arg_grads handle array
+ * \param out_size number of auxiliary states returned
+ * \param out_array returned auxiliary states handle array
+ * \param shared_exec input executor handle for memory sharing
+ * \param out output executor handle
+ * \return 0 when success, -1 when failure happens (the new executor is written to out)
+ * \note inputs declared as `in` are named above by position; `in` matches the SWIG input typemaps
+ */
+int MXExecutorReshape(int partial_shaping,
+                                int allow_up_sizing,
+                                int dev_type,
+                                int dev_id,
+                                mx_uint num_map_keys,
+                                const char** in,
+                                const int* in,
+                                const int* in,
+                                const mx_uint num_provided_arg_shapes,
+                                const char** in,
+                                const mx_uint* in,
+                                const mx_uint* in,
+                                mx_uint* couple_out_size,
+                                NDArrayHandle** out_first_array,
+                                NDArrayHandle** out_second_array,
+                                mx_uint* out_size,
+                                NDArrayHandle** out_array,
+                                ExecutorHandle shared_exec,
+                                ExecutorHandle *out);
+
+/*!
  * \brief set a call back to notify the completion of operation
  */
 int MXExecutorSetMonitorCallback(ExecutorHandle handle,
diff --git a/perl-package/AI-MXNetCAPI/mxnet_typemaps.i b/perl-package/AI-MXNetCAPI/mxnet_typemaps.i
index 2c11388..4d9177a 100644
--- a/perl-package/AI-MXNetCAPI/mxnet_typemaps.i
+++ b/perl-package/AI-MXNetCAPI/mxnet_typemaps.i
@@ -161,6 +161,7 @@
                          (mx_uint *out_size, const char ***out_array) (mx_uint temp_size, char** temp)
 {
     $1 = &temp_size;
+    *$1 = 0;
     $2 = &temp;
 }
 
@@ -188,6 +189,7 @@
%typemap(in,numinputs=0) (mx_uint *out_size, const char ***out_array2) (mx_uint temp_size, char** temp)
 {
     $1 = &temp_size;
+    *$1 = 0;    /* give the count a defined value before the C call fills it */
     $2 = &temp;
 }
 
@@ -300,6 +302,37 @@
     Safefree($1);
 }
 
+/* Convert a Perl array reference of integers into a newly allocated dim_t array. */
+%typemap(in) (dim_t *in)
+{
+    AV *tempav;
+    SV **tv;
+    int i, len;
+    if (!SvROK($input))
+        croak("Argument $argnum is not a reference.");
+    if (SvTYPE(SvRV($input)) != SVt_PVAV)
+        croak("Argument $argnum is not an array.");
+    tempav = (AV*)SvRV($input);
+    len = av_len(tempav) + 1;
+    $1 = NULL; /* an empty array is passed on as a NULL pointer */
+    if (len)
+    {
+        $1 = (dim_t *)safemalloc(len*sizeof(dim_t));
+        for (i = 0; i < len; i++) {
+            tv = av_fetch(tempav, i, 0);
+            if (!tv) { /* av_fetch returns NULL for a nonexistent element */
+                Safefree($1);
+                croak("Argument $argnum contains a nonexistent element.");
+            }
+            $1[i] = (dim_t)SvIV(*tv);
+        }
+    }
+}
+
+%typemap(freearg) (dim_t *in) {
+    Safefree($1);
+}
+
 %typemap(in) (NDArrayHandle* in), (SymbolHandle* in)
 {
     AV *tempav;
@@ -449,6 +482,7 @@
 %typemap(in,numinputs=0) (char const **out_array, size_t *out_size) (char * temp, size_t temp_size)
 {
     $2 = &temp_size;
+    *$2 = 0;    /* give the size a defined value before the C call fills it */
     $1 = &temp;
 }
 
@@ -465,6 +499,7 @@
 %typemap(in,numinputs=0) (size_t *out_size, char const **out_array) (size_t temp_size, char *temp)
 {
     $1 = &temp_size;
+    *$1 = 0;    /* give the size a defined value before the C call fills it */
     $2 = &temp;
 }
 
@@ -508,6 +543,7 @@
 {
     $1 = &temp1;
     $2 = &temp2;
+    *$2 = 0;
 }
 
 %typemap(argout) (uint64_t **out_index, uint64_t *out_size)
@@ -536,6 +572,7 @@
                          (mx_uint *out_size, NDArrayHandle** out_array) (mx_uint temp_size, NDArrayHandle* temp)
 {
     $1 = &temp_size;
+    *$1 = 0;
     $2 = &temp;
 }
 
@@ -616,6 +653,36 @@
     }
 }
 
+%typemap(in,numinputs=0) (mx_uint* couple_out_size, NDArrayHandle** out_first_array, NDArrayHandle** out_second_array)
+                         (mx_uint t, NDArrayHandle* t1, NDArrayHandle* t2)
+{
+    $1 = &t;
+    *$1 = 0;
+    $2 = &t1;
+    $3 = &t2;
+}
+
+%typemap(argout) (mx_uint* couple_out_size, NDArrayHandle** out_first_array, NDArrayHandle** out_second_array)
+{
+    if(!result)
+    {
+        AV *container, *in_args, *arg_grads;
+        mx_uint i;
+        container = newAV();
+        in_args = newAV();
+        arg_grads = newAV();
+        for (i = 0; i < *$1; i++) {
+            av_push(in_args, SvREFCNT_inc(SWIG_NewPointerObj(SWIG_as_voidptr((*$2)[i]), SWIGTYPE_p_MXNDArray, 0)));
+            av_push(arg_grads, SvREFCNT_inc(SWIG_NewPointerObj(SWIG_as_voidptr((*$3)[i]), SWIGTYPE_p_MXNDArray, 0)));
+        }
+        av_push(container, newRV_noinc((SV*)in_args));
+        av_push(container, newRV_noinc((SV*)arg_grads));
+        if (argvi >= items) EXTEND(sp, 1); /* make room on the Perl stack for this extra return value */
+        $result = sv_2mortal(newRV_noinc((SV*)container));
+        argvi++;
+    }
+}
+
 %typemap(in,numinputs=0) (NDArrayHandle **out_grad) (NDArrayHandle* temp)
 {
     int vars = SvIV(ST(3));
@@ -629,6 +696,7 @@
     }
 }
 
+
 %typemap(argout) (NDArrayHandle** out_grad)
 {
     if(!result)
@@ -756,6 +824,7 @@
     $1 = &name_temp;
     $2 = &desc_temp;
     $3 = &num_args_temp;
+    *$3 = 0;
     $4 = &names_temp;
     $5 = &types_temp;
     $6 = &descs_temp;
@@ -815,7 +884,8 @@
 {
     $1 = &name_temp; 
     $2 = &desc_temp;
-    $3 = &num_args_temp; 
+    $3 = &num_args_temp;
+    *$3 = 0;
     $4 = &names_temp;
     $5 = &types_temp;
     $6 = &descs_temp;
@@ -861,7 +931,8 @@
 
%typemap(in,numinputs=0) (mx_uint *out) (mx_uint temp), (size_t *out) (size_t temp)
 {
-    $1 = &temp; 
+    $1 = &temp;
+    *$1 = 0;    /* give the output a defined value before the C call fills it */
 }
 
 %typemap(argout) (mx_uint *out), (size_t *out)
diff --git a/perl-package/AI-NNVMCAPI/Changes b/perl-package/AI-NNVMCAPI/Changes
index 8c944d9..62aa042 100644
--- a/perl-package/AI-NNVMCAPI/Changes
+++ b/perl-package/AI-NNVMCAPI/Changes
@@ -1,5 +1,8 @@
 Revision history for Perl extension AI::NNVMCAPI.
 
+1.3     Tue Jun 26 20:57:40 PDT 2018
+        - Major update, Gluon interface updated to parity with Python's API
+
 1.2     Sun Mar  4 16:29:19 PST 2018
         - Support for sparse tensors
 
diff --git a/perl-package/AI-NNVMCAPI/META.json b/perl-package/AI-NNVMCAPI/META.json
index 4457e6f..3851c9d 100644
--- a/perl-package/AI-NNVMCAPI/META.json
+++ b/perl-package/AI-NNVMCAPI/META.json
@@ -37,5 +37,5 @@
       }
    },
    "release_status" : "stable",
-   "version" : "1.2"
+   "version" : "1.3"
 }
diff --git a/perl-package/AI-NNVMCAPI/META.yml b/perl-package/AI-NNVMCAPI/META.yml
index e7c01f8..e462637 100644
--- a/perl-package/AI-NNVMCAPI/META.yml
+++ b/perl-package/AI-NNVMCAPI/META.yml
@@ -19,4 +19,4 @@ no_index:
     - inc
 requires:
   Test::More: '0'
-version: '1.2'
+version: '1.3'
diff --git a/perl-package/AI-NNVMCAPI/README b/perl-package/AI-NNVMCAPI/README
index cc7c2f1..092fd31 100644
--- a/perl-package/AI-NNVMCAPI/README
+++ b/perl-package/AI-NNVMCAPI/README
@@ -1,4 +1,4 @@
-AI-NNVMCAPI version 1.2
+AI-NNVMCAPI version 1.3
 =====================
 
 Swig interface to MXNet c api.
diff --git a/perl-package/AI-NNVMCAPI/lib/AI/NNVMCAPI.pm b/perl-package/AI-NNVMCAPI/lib/AI/NNVMCAPI.pm
index 81453cd..92a8f74 100644
--- a/perl-package/AI-NNVMCAPI/lib/AI/NNVMCAPI.pm
+++ b/perl-package/AI-NNVMCAPI/lib/AI/NNVMCAPI.pm
@@ -18,7 +18,7 @@
 package AI::NNVMCAPI;
 use base qw(DynaLoader);
 bootstrap AI::NNVMCAPI;
-our $VERSION = '1.2';
+our $VERSION = '1.3';    # bumped together with META.json, META.yml, README and Changes
 1;
 __END__
 
diff --git a/perl-package/AI-NNVMCAPI/nnvm_typemaps.i b/perl-package/AI-NNVMCAPI/nnvm_typemaps.i
index 818ed1e..e64b3c9 100644
--- a/perl-package/AI-NNVMCAPI/nnvm_typemaps.i
+++ b/perl-package/AI-NNVMCAPI/nnvm_typemaps.i
@@ -96,6 +96,7 @@
%typemap(in,numinputs=0) (int *out) (int temp)
 {
     $1 = &temp;
+    *$1 = 0;    /* give the output a defined value before the C call fills it */
 }
 
 %typemap(argout) (int *out)
@@ -112,6 +113,7 @@
                          (mx_uint *out_size, const char ***out_array) (mx_uint temp_size, char** temp)
 {
     $1 = &temp_size;
+    *$1 = 0;
     $2 = &temp;
 }
 
@@ -139,6 +141,7 @@
%typemap(in,numinputs=0) (nn_uint *half_of_out_size, const char ***out_array) (nn_uint temp_size, char **temp)
 {
     $1 = &temp_size;
+    *$1 = 0;    /* give the count a defined value before the C call fills it */
     $2 = &temp;
 }
 %typemap(argout) (nn_uint *half_of_out_size, const char ***out_array)
@@ -279,6 +282,7 @@
%typemap(in,numinputs=0) (nn_uint *out_size, OpHandle** out_array) (nn_uint temp_num, OpHandle* temp)
 {
     $1 = &temp_num;
+    *$1 = 0;    /* give the count a defined value before the C call fills it */
     $2 = &temp;
 }
 %typemap(argout) (nn_uint *out_size, OpHandle** out_array)
@@ -303,6 +307,7 @@
%typemap(in,numinputs=0) (nn_uint *out_size, SymbolHandle** out_array) (nn_uint temp_num, SymbolHandle* temp)
 {
     $1 = &temp_num;
+    *$1 = 0;    /* give the count a defined value before the C call fills it */
     $2 = &temp;
 }
 %typemap(argout) (nn_uint *out_size, SymbolHandle** out_array)