You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ib...@apache.org on 2019/09/11 04:20:32 UTC
[incubator-mxnet] 01/01: julia: fix `mx.forward` kwargs checking
This is an automated email from the ASF dual-hosted git repository.
iblis pushed a commit to branch ib/fix-forward
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit bb8a6f42b89b090c9b134d5710157606e9b99494
Author: Iblis Lin <ib...@hs.ntnu.edu.tw>
AuthorDate: Wed Sep 11 04:18:45 2019 +0000
julia: fix `mx.forward` kwargs checking
close https://github.com/dmlc/MXNet.jl/issues/431
---
julia/src/executor.jl | 2 +-
julia/test/unittest/bind.jl | 15 +++++++++++++++
2 files changed, 16 insertions(+), 1 deletion(-)
diff --git a/julia/src/executor.jl b/julia/src/executor.jl
index e565617..37f2dde 100644
--- a/julia/src/executor.jl
+++ b/julia/src/executor.jl
@@ -176,7 +176,7 @@ end
function forward(self::Executor; is_train::Bool = false, kwargs...)
for (k,v) in kwargs
- @assert(k ∈ self.arg_dict, "Unknown argument $k")
+ @assert(k ∈ keys(self.arg_dict), "Unknown argument $k")
@assert(isa(v, NDArray), "Keyword argument $k must be an NDArray")
copy!(self.arg_dict[k], v)
end
diff --git a/julia/test/unittest/bind.jl b/julia/test/unittest/bind.jl
index 0ae0ab4..a221733 100644
--- a/julia/test/unittest/bind.jl
+++ b/julia/test/unittest/bind.jl
@@ -84,11 +84,26 @@ function test_arithmetic()
end
end
+function test_forward()
+ # forward with data keyword argument
+ x = @var x
+ y = x .+ 42
+
+ A = 1:5
+ B = A .+ 42
+
+ e = bind(y, args = Dict(:x => NDArray(24:28)))
+ z = forward(e, x = NDArray(A))[1]
+
+ @test copy(z) == collect(B)
+end
+
################################################################################
# Run tests
################################################################################
@testset "Bind Test" begin
test_arithmetic()
+ test_forward()
end
end