You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ib...@apache.org on 2019/01/18 03:08:06 UTC
[incubator-mxnet] branch master updated: julia: fix `argmax` for NDArray (#13871)
This is an automated email from the ASF dual-hosted git repository.
iblis pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 61847be julia: fix `argmax` for NDArray (#13871)
61847be is described below
commit 61847bebf5cf807680740542afeeacda5231ace9
Author: Iblis Lin <ib...@hs.ntnu.edu.tw>
AuthorDate: Fri Jan 18 11:07:29 2019 +0800
julia: fix `argmax` for NDArray (#13871)
- fix 0-based index output to 1-based index
close #13786
---
julia/src/ndarray.jl | 65 ++++++++++++++++++++++++++++++++++++++++++
julia/test/unittest/ndarray.jl | 46 ++++++++++++++++++++++++++++++
2 files changed, 111 insertions(+)
diff --git a/julia/src/ndarray.jl b/julia/src/ndarray.jl
index dad9b59..6987d57 100644
--- a/julia/src/ndarray.jl
+++ b/julia/src/ndarray.jl
@@ -1273,6 +1273,67 @@ Base.prod(x::NDArray; dims = :) = _prod(x, dims)
@_remap _prod(x::NDArray, ::Colon) prod(x)
@_remap _prod(x::NDArray, dims) prod(x; axis = 0 .- dims, keepdims = true)
+# TODO: support returning a CartesianIndex (like Base.argmax on Arrays)?
+"""
+    argmax(x::NDArray; dims = :) -> indices
+
+Return the 1-based indices of the maximum values of `x` along `dims`.
+`NaN` is skipped during comparison; this differs from Julia `Base.argmax`.
+
+## Examples
+
+```julia-repl
+julia> x = NDArray([0. 1 2; 3 4 5])
+2×3 NDArray{Float64,2} @ CPU0:
+ 0.0 1.0 2.0
+ 3.0 4.0 5.0
+
+julia> argmax(x, dims = 1)
+1×3 NDArray{Float64,2} @ CPU0:
+ 2.0 2.0 2.0
+
+julia> argmax(x, dims = 2)
+2×1 NDArray{Float64,2} @ CPU0:
+ 3.0
+ 3.0
+```
+
+See also [`argmin`](@ref mx.argmin).
+"""
+Base.argmax(x::NDArray; dims = :) = _argmax(x, dims) .+ 1  # libmxnet returns 0-based indices
+@_remap _argmax(x::NDArray, ::Colon) argmax(x)  # dims = : -> no `axis` passed to libmxnet
+@_remap _argmax(x::NDArray, dims) argmax(x; axis = 0 .- dims, keepdims = true)  # negated dims: presumably Julia dim d == libmxnet axis -d (reversed order) — confirm
+
+"""
+    argmin(x::NDArray; dims = :) -> indices
+
+Return the 1-based indices of the minimum values of `x` along `dims`.
+`NaN` is skipped during comparison; this differs from Julia `Base.argmin`.
+
+## Examples
+
+```julia-repl
+julia> x = NDArray([0. 1 2; 3 4 5])
+2×3 NDArray{Float64,2} @ CPU0:
+ 0.0 1.0 2.0
+ 3.0 4.0 5.0
+
+julia> argmin(x, dims = 1)
+1×3 NDArray{Float64,2} @ CPU0:
+ 1.0 1.0 1.0
+
+julia> argmin(x, dims = 2)
+2×1 NDArray{Float64,2} @ CPU0:
+ 1.0
+ 1.0
+```
+
+See also [`argmax`](@ref mx.argmax).
+"""
+Base.argmin(x::NDArray; dims = :) = _argmin(x, dims) .+ 1  # libmxnet returns 0-based indices
+@_remap _argmin(x::NDArray, ::Colon) argmin(x)  # dims = : -> no `axis` passed to libmxnet
+@_remap _argmin(x::NDArray, dims) argmin(x; axis = 0 .- dims, keepdims = true)  # negated dims: presumably Julia dim d == libmxnet axis -d (reversed order) — confirm
+
_nddoc[:clip] = _nddoc[:clip!] =
"""
clip(x::NDArray, min, max)
@@ -1734,6 +1795,10 @@ const _op_import_bl = [ # import black list; do not import these funcs
"broadcast_axis",
"broadcast_axes",
"broadcast_hypot",
+
+ # reduction
+ "argmax",
+ "argmin",
]
macro _import_ndarray_functions()
diff --git a/julia/test/unittest/ndarray.jl b/julia/test/unittest/ndarray.jl
index 9ca4ba2..85328ff 100644
--- a/julia/test/unittest/ndarray.jl
+++ b/julia/test/unittest/ndarray.jl
@@ -1434,6 +1434,50 @@ function test_hypot()
@test copy(z) == C
end # function test_hypot
+function test_argmax()
+    # NDArray argmax must report 1-based indices, like Julia arrays
+    @info "NDArray::argmax"
+    let nd = NDArray([1.0 5.0 3.0;
+                      4.0 2.0 6.0])
+        # column maxima: 4 (row 2), 5 (row 1), 6 (row 2)
+        @test copy(argmax(nd, dims = 1)) == [2 1 2]
+        # row maxima: 5 (col 2), 6 (col 3)
+        @test copy(argmax(nd, dims = 2)) == reshape([2, 3], :, 1)
+    end
+
+    # NaN entries must be skipped rather than propagated
+    @info "NDArray::argmax::NaN"
+    let nd = NDArray([1.0 5.0 3.0;
+                      NaN 2.0 6.0])
+        # column 1 is [1, NaN], so row 1 wins
+        @test copy(argmax(nd, dims = 1)) == [1 1 2]
+        # row 2 is [NaN, 2, 6], so column 3 wins
+        @test copy(argmax(nd, dims = 2)) == reshape([2, 3], :, 1)
+    end
+end
+
+function test_argmin()
+    # NDArray argmin must report 1-based indices, like Julia arrays
+    @info "NDArray::argmin"
+    let nd = NDArray([1.0 5.0 3.0;
+                      4.0 2.0 6.0])
+        # column minima: 1 (row 1), 2 (row 2), 3 (row 1)
+        @test copy(argmin(nd, dims = 1)) == [1 2 1]
+        # row minima: 1 (col 1), 2 (col 2)
+        @test copy(argmin(nd, dims = 2)) == reshape([1, 2], :, 1)
+    end
+
+    # NaN entries must be skipped rather than propagated
+    @info "NDArray::argmin::NaN"
+    let nd = NDArray([1.0 5.0 3.0;
+                      NaN 2.0 6.0])
+        # column 1 is [1, NaN], so row 1 wins
+        @test copy(argmin(nd, dims = 1)) == [1 2 1]
+        # row 2 is [NaN, 2, 6], so column 2 wins
+        @test copy(argmin(nd, dims = 2)) == reshape([1, 2], :, 1)
+    end
+end
+
################################################################################
# Run tests
################################################################################
@@ -1479,6 +1523,8 @@ end # function test_hypot
test_broadcast_to()
test_broadcast_axis()
test_hypot()
+ test_argmax()
+ test_argmin()
end
end