You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ib...@apache.org on 2019/01/18 03:08:06 UTC

[incubator-mxnet] branch master updated: julia: fix `argmax` for NDArray (#13871)

This is an automated email from the ASF dual-hosted git repository.

iblis pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 61847be  julia: fix `argmax` for NDArray (#13871)
61847be is described below

commit 61847bebf5cf807680740542afeeacda5231ace9
Author: Iblis Lin <ib...@hs.ntnu.edu.tw>
AuthorDate: Fri Jan 18 11:07:29 2019 +0800

    julia: fix `argmax` for NDArray (#13871)
    
    - fix 0-based index output to 1-based index
    
    close #13786
---
 julia/src/ndarray.jl           | 65 ++++++++++++++++++++++++++++++++++++++++++
 julia/test/unittest/ndarray.jl | 46 ++++++++++++++++++++++++++++++
 2 files changed, 111 insertions(+)

diff --git a/julia/src/ndarray.jl b/julia/src/ndarray.jl
index dad9b59..6987d57 100644
--- a/julia/src/ndarray.jl
+++ b/julia/src/ndarray.jl
@@ -1273,6 +1273,67 @@ Base.prod(x::NDArray; dims = :) = _prod(x, dims)
 @_remap _prod(x::NDArray, ::Colon) prod(x)
 @_remap _prod(x::NDArray, dims)    prod(x; axis = 0 .- dims, keepdims = true)
 
+# TODO: support CartesianIndex ?
+"""
+    argmax(x::NDArray; dims) -> indices
+
+Note that `NaN` is skipped during comparison.
+This is different from Julia `Base.argmax`.
+
+## Examples
+
+```julia-repl
+julia> x = NDArray([0. 1 2; 3 4 5])
+2×3 NDArray{Float64,2} @ CPU0:
+ 0.0  1.0  2.0
+ 3.0  4.0  5.0
+
+julia> argmax(x, dims = 1)
+1×3 NDArray{Float64,2} @ CPU0:
+ 2.0  2.0  2.0
+
+julia> argmax(x, dims = 2)
+2×1 NDArray{Float64,2} @ CPU0:
+ 3.0
+ 3.0
+```
+
+See also [`argmin`](@ref mx.argmin).
+"""
+Base.argmax(x::NDArray; dims = :) = _argmax(x, dims) .+ 1
+@_remap _argmax(x::NDArray, ::Colon) argmax(x)
+@_remap _argmax(x::NDArray, dims)    argmax(x; axis = 0 .- dims, keepdims = true)
+
+"""
+    argmin(x::NDArray; dims) -> indices
+
+Note that `NaN` is skipped during comparison.
+This is different from Julia `Base.argmin`.
+
+## Examples
+
+```julia-repl
+julia> x = NDArray([0. 1 2; 3 4 5])
+2×3 NDArray{Float64,2} @ CPU0:
+ 0.0  1.0  2.0
+ 3.0  4.0  5.0
+
+julia> argmin(x, dims = 1)
+1×3 NDArray{Float64,2} @ CPU0:
+ 1.0  1.0  1.0
+
+julia> argmin(x, dims = 2)
+2×1 NDArray{Float64,2} @ CPU0:
+ 1.0
+ 1.0
+```
+
+See also [`argmax`](@ref mx.argmax).
+"""
+Base.argmin(x::NDArray; dims = :) = _argmin(x, dims) .+ 1
+@_remap _argmin(x::NDArray, ::Colon) argmin(x)
+@_remap _argmin(x::NDArray, dims)    argmin(x; axis = 0 .- dims, keepdims = true)
+
 _nddoc[:clip] = _nddoc[:clip!] =
 """
     clip(x::NDArray, min, max)
@@ -1734,6 +1795,10 @@ const _op_import_bl = [  # import black list; do not import these funcs
     "broadcast_axis",
     "broadcast_axes",
     "broadcast_hypot",
+
+    # reduction
+    "argmax",
+    "argmin",
 ]
 
 macro _import_ndarray_functions()
diff --git a/julia/test/unittest/ndarray.jl b/julia/test/unittest/ndarray.jl
index 9ca4ba2..85328ff 100644
--- a/julia/test/unittest/ndarray.jl
+++ b/julia/test/unittest/ndarray.jl
@@ -1434,6 +1434,50 @@ function test_hypot()
   @test copy(z) == C
 end  # function test_hypot
 
+function test_argmax()
+  @info "NDArray::argmax"
+  let
+    A = [1. 5 3;
+         4 2 6]
+    x = NDArray(A)
+
+    @test copy(argmax(x, dims = 1)) == [2 1 2]
+    @test copy(argmax(x, dims = 2)) == reshape([2, 3], :, 1)
+  end
+
+  @info "NDArray::argmax::NaN"
+  let
+    A = [1.  5 3;
+         NaN 2 6]
+    x = NDArray(A)
+
+    @test copy(argmax(x, dims = 1)) == [1 1 2]
+    @test copy(argmax(x, dims = 2)) == reshape([2, 3], :, 1)
+  end
+end
+
+function test_argmin()
+  @info "NDArray::argmin"
+  let
+    A = [1. 5 3;
+         4 2 6]
+    x = NDArray(A)
+
+    @test copy(argmin(x, dims = 1)) == [1 2 1]
+    @test copy(argmin(x, dims = 2)) == reshape([1, 2], :, 1)
+  end
+
+  @info "NDArray::argmin::NaN"
+  let
+    A = [1.  5 3;
+         NaN 2 6]
+    x = NDArray(A)
+
+    @test copy(argmin(x, dims = 1)) == [1 2 1]
+    @test copy(argmin(x, dims = 2)) == reshape([1, 2], :, 1)
+  end
+end
+
 ################################################################################
 # Run tests
 ################################################################################
@@ -1479,6 +1523,8 @@ end  # function test_hypot
   test_broadcast_to()
   test_broadcast_axis()
   test_hypot()
+  test_argmax()
+  test_argmin()
 end
 
 end