You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ib...@apache.org on 2019/01/14 12:45:00 UTC

[incubator-mxnet] branch ib/nd-argmax-argmin created (now 267e7c2)

This is an automated email from the ASF dual-hosted git repository.

iblis pushed a change to branch ib/nd-argmax-argmin
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git.


      at 267e7c2  julia: fix `argmax` for NDArray

This branch includes the following new commits:

     new 267e7c2  julia: fix `argmax` for NDArray

The 1 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.



[incubator-mxnet] 01/01: julia: fix `argmax` for NDArray

Posted by ib...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

iblis pushed a commit to branch ib/nd-argmax-argmin
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 267e7c2f470fe9ff3015988543f2363d5896aefc
Author: Iblis Lin <ib...@hs.ntnu.edu.tw>
AuthorDate: Mon Jan 14 20:40:46 2019 +0800

    julia: fix `argmax` for NDArray
    
    - fix 0-based index output to 1-based index
    
    close #13786
---
 julia/src/ndarray.jl           | 65 ++++++++++++++++++++++++++++++++++++++++++
 julia/test/unittest/ndarray.jl | 46 ++++++++++++++++++++++++++++++
 2 files changed, 111 insertions(+)

diff --git a/julia/src/ndarray.jl b/julia/src/ndarray.jl
index dad9b59..6987d57 100644
--- a/julia/src/ndarray.jl
+++ b/julia/src/ndarray.jl
@@ -1273,6 +1273,67 @@ Base.prod(x::NDArray; dims = :) = _prod(x, dims)
 @_remap _prod(x::NDArray, ::Colon) prod(x)
 @_remap _prod(x::NDArray, dims)    prod(x; axis = 0 .- dims, keepdims = true)
 
+# TODO: support CartesianIndex ?
+"""
+    argmax(x::NDArray; dims = :) -> indices
+
+Return the 1-based indices of the maxima of `x` along `dims`, as an `NDArray`.
+`NaN` is skipped during comparison, which is different from Julia `Base.argmax`.
+
+## Examples
+
+```julia-repl
+julia> x = NDArray([0. 1 2; 3 4 5])
+2×3 NDArray{Float64,2} @ CPU0:
+ 0.0  1.0  2.0
+ 3.0  4.0  5.0
+
+julia> argmax(x, dims = 1)
+1×3 NDArray{Float64,2} @ CPU0:
+ 2.0  2.0  2.0
+
+julia> argmax(x, dims = 2)
+2×1 NDArray{Float64,2} @ CPU0:
+ 3.0
+ 3.0
+```
+
+See also [`argmin`](@ref mx.argmin).
+"""
+Base.argmax(x::NDArray; dims = :) = _argmax(x, dims) .+ 1  # 0-based (MXNet) -> 1-based
+@_remap _argmax(x::NDArray, ::Colon) argmax(x)  # dims = : -> no `axis` passed
+@_remap _argmax(x::NDArray, dims)    argmax(x; axis = 0 .- dims, keepdims = true)  # dim d -> axis -d
+
+"""
+    argmin(x::NDArray; dims = :) -> indices
+
+Return the 1-based indices of the minima of `x` along `dims`, as an `NDArray`.
+`NaN` is skipped during comparison, which is different from Julia `Base.argmin`.
+
+## Examples
+
+```julia-repl
+julia> x = NDArray([0. 1 2; 3 4 5])
+2×3 NDArray{Float64,2} @ CPU0:
+ 0.0  1.0  2.0
+ 3.0  4.0  5.0
+
+julia> argmin(x, dims = 1)
+1×3 NDArray{Float64,2} @ CPU0:
+ 1.0  1.0  1.0
+
+julia> argmin(x, dims = 2)
+2×1 NDArray{Float64,2} @ CPU0:
+ 1.0
+ 1.0
+```
+
+See also [`argmax`](@ref mx.argmax).
+"""
+Base.argmin(x::NDArray; dims = :) = _argmin(x, dims) .+ 1  # 0-based (MXNet) -> 1-based
+@_remap _argmin(x::NDArray, ::Colon) argmin(x)  # dims = : -> no `axis` passed
+@_remap _argmin(x::NDArray, dims)    argmin(x; axis = 0 .- dims, keepdims = true)  # dim d -> axis -d
+
 _nddoc[:clip] = _nddoc[:clip!] =
 """
     clip(x::NDArray, min, max)
@@ -1734,6 +1795,10 @@ const _op_import_bl = [  # import black list; do not import these funcs
     "broadcast_axis",
     "broadcast_axes",
     "broadcast_hypot",
+
+    # reduction
+    "argmax",
+    "argmin",
 ]
 
 macro _import_ndarray_functions()
diff --git a/julia/test/unittest/ndarray.jl b/julia/test/unittest/ndarray.jl
index 9ca4ba2..85328ff 100644
--- a/julia/test/unittest/ndarray.jl
+++ b/julia/test/unittest/ndarray.jl
@@ -1434,6 +1434,50 @@ function test_hypot()
   @test copy(z) == C
 end  # function test_hypot
 
+function test_argmax()
+  @info "NDArray::argmax"
+  let
+    vals = [1. 5 3;
+            4  2 6]
+    nd = NDArray(vals)
+    # expected indices are 1-based; the reduced dim is kept (hence the 2×1 reshape)
+    @test copy(argmax(nd, dims = 1)) == [2 1 2]                 # row of each column's max
+    @test copy(argmax(nd, dims = 2)) == reshape([2, 3], 2, 1)   # column of each row's max
+  end
+
+  @info "NDArray::argmax::NaN"
+  let
+    vals = [1.  5 3;
+            NaN 2 6]
+    nd = NDArray(vals)
+    # the NaN at (2, 1) is skipped: row 1 wins column 1, column 3 wins row 2
+    @test copy(argmax(nd, dims = 1)) == [1 1 2]
+    @test copy(argmax(nd, dims = 2)) == reshape([2, 3], 2, 1)
+  end
+end
+
+function test_argmin()
+  @info "NDArray::argmin"
+  let
+    vals = [1. 5 3;
+            4  2 6]
+    nd = NDArray(vals)
+    # expected indices are 1-based; the reduced dim is kept (hence the 2×1 reshape)
+    @test copy(argmin(nd, dims = 1)) == [1 2 1]                 # row of each column's min
+    @test copy(argmin(nd, dims = 2)) == reshape([1, 2], 2, 1)   # column of each row's min
+  end
+
+  @info "NDArray::argmin::NaN"
+  let
+    vals = [1.  5 3;
+            NaN 2 6]
+    nd = NDArray(vals)
+    # the NaN at (2, 1) is skipped: row 1 wins column 1, column 2 wins row 2
+    @test copy(argmin(nd, dims = 1)) == [1 2 1]
+    @test copy(argmin(nd, dims = 2)) == reshape([1, 2], 2, 1)
+  end
+end
+
 ################################################################################
 # Run tests
 ################################################################################
@@ -1479,6 +1523,8 @@ end  # function test_hypot
   test_broadcast_to()
   test_broadcast_axis()
   test_hypot()
+  test_argmax()
+  test_argmin()
 end
 
 end