You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ib...@apache.org on 2019/09/11 04:20:32 UTC

[incubator-mxnet] 01/01: julia: fix `mx.forward` kwargs checking

This is an automated email from the ASF dual-hosted git repository.

iblis pushed a commit to branch ib/fix-forward
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit bb8a6f42b89b090c9b134d5710157606e9b99494
Author: Iblis Lin <ib...@hs.ntnu.edu.tw>
AuthorDate: Wed Sep 11 04:18:45 2019 +0000

    julia: fix `mx.forward` kwargs checking
    
    close https://github.com/dmlc/MXNet.jl/issues/431
---
 julia/src/executor.jl       |  2 +-
 julia/test/unittest/bind.jl | 15 +++++++++++++++
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/julia/src/executor.jl b/julia/src/executor.jl
index e565617..37f2dde 100644
--- a/julia/src/executor.jl
+++ b/julia/src/executor.jl
@@ -176,7 +176,7 @@ end
 
 function forward(self::Executor; is_train::Bool = false, kwargs...)
   for (k,v) in kwargs
-    @assert(k ∈ self.arg_dict, "Unknown argument $k")
+    @assert(k ∈ keys(self.arg_dict), "Unknown argument $k")
     @assert(isa(v, NDArray), "Keyword argument $k must be an NDArray")
     copy!(self.arg_dict[k], v)
   end
diff --git a/julia/test/unittest/bind.jl b/julia/test/unittest/bind.jl
index 0ae0ab4..a221733 100644
--- a/julia/test/unittest/bind.jl
+++ b/julia/test/unittest/bind.jl
@@ -84,11 +84,26 @@ function test_arithmetic()
   end
 end
 
+function test_forward()
+  # regression: forward must accept a kwarg that names a bound argument (x = ...)
+  x = @var x                 # symbolic input, bound below under the key :x
+  y = x .+ 42                # graph under test: y = x + 42
+
+  A = 1:5                    # data passed through forward's keyword argument
+  B = A .+ 42                # expected output, 43:47
+
+  e = bind(y, args = Dict(:x => NDArray(24:28)))  # initial :x, replaced by the kwarg
+  z = forward(e, x = NDArray(A))[1]  # [1] presumably selects the first output
+
+  @test copy(z) == collect(B)  # output must come from A, not the bound 24:28
+end
+
 ################################################################################
 # Run tests
 ################################################################################
 @testset "Bind Test" begin
   test_arithmetic()
+  test_forward()  # regression test for forward's keyword-argument check
 end
 
 end