You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ib...@apache.org on 2019/12/21 18:08:55 UTC

[incubator-mxnet] 01/01: julia: porting `current_context`

This is an automated email from the ASF dual-hosted git repository.

iblis pushed a commit to branch ib/jl-ctx-current
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 3bb242b6f6d4656776aa27373f4f328e51ee4c02
Author: Iblis Lin <ib...@hs.ntnu.edu.tw>
AuthorDate: Sat Dec 21 18:06:31 2019 +0000

    julia: porting `current_context`
    
    - Also introduce a new macro, `@with_context`, for temporarily changing
      the default context
---
 julia/src/context.jl       | 65 +++++++++++++++++++++++++++++++++++++++++++++-
 julia/src/ndarray/array.jl | 18 +++++++------
 julia/src/ndarray/type.jl  |  2 +-
 3 files changed, 75 insertions(+), 10 deletions(-)

diff --git a/julia/src/context.jl b/julia/src/context.jl
index 68e6913..c48bb40 100644
--- a/julia/src/context.jl
+++ b/julia/src/context.jl
@@ -31,11 +31,50 @@ struct Context
   Context(dev_type::CONTEXT_TYPE, dev_id::Integer = 0) = new(dev_type, dev_id)
 end
 
+const _default_ctx = Ref(Context(CPU, 0))  # read by `current_context`, swapped by `@with_context`
+
 Context(dev_type::Integer, dev_id::Integer = 0) =
   Context(convert(CONTEXT_TYPE, dev_type), dev_id)
 
 Base.show(io::IO, ctx::Context) =
-  print(io, "$(ctx.device_type)$(ctx.device_id)")
+  print(io, lowercase(string(ctx.device_type)), ctx.device_id)
+
+function _with_context(dev_type, dev_id, e)
+  # escape caller code (caller scope); no `return`, which exits the enclosing function
+  quote
+    ctx = current_context()
+    ctx′ = Context($(esc(dev_type)), $(esc(dev_id)))
+    $_default_ctx[] = ctx′
+    try
+      $(esc(e))
+    finally
+      $_default_ctx[] = ctx
+    end
+  end
+end
+
+"""
+    @with_context device_type [device_id] expr
+
+Evaluate `expr` with the default context set to `Context(device_type, device_id)`;
+`device_id` defaults to `0`. The previous context is restored even if `expr` throws.
+
+# Examples
+```jl-repl
+julia> mx.@with_context mx.GPU begin
+         mx.zeros(2, 3)
+       end
+2×3 NDArray{Float32,2} @ gpu0:
+ 0.0f0  0.0f0  0.0f0
+ 0.0f0  0.0f0  0.0f0
+```
+"""
+macro with_context(dev_type, e)
+  _with_context(dev_type, 0, e)
+end
+macro with_context(dev_type, dev_id, e)
+  _with_context(dev_type, dev_id, e)
+end
 
 """
     cpu(dev_id)
@@ -86,3 +125,27 @@ function gpu_memory_info(dev_id = 0)
   @mxcall :MXGetGPUMemoryInformation64 (Cint, Ref{UInt64}, Ref{UInt64}) dev_id free n
   free[], n[]
 end
+
+"""
+    current_context()
+
+Return the default `Context`, used by `zeros`, `ones`, `fill` etc. when no `ctx` is given.
+
+This is `mx.cpu()` unless overridden inside an `@with_context` block, which
+restores the previous context when the block exits.
+
+# Examples
+```jl-repl
+julia> mx.current_context()
+cpu0
+
+julia> mx.@with_context mx.GPU 1 begin  # Context changed in the following code block
+         mx.current_context()
+       end
+gpu1
+
+julia> mx.current_context()
+cpu0
+```
+"""
+current_context()::Context = _default_ctx[]
diff --git a/julia/src/ndarray/array.jl b/julia/src/ndarray/array.jl
index b71e5dd..2cd9c2e 100644
--- a/julia/src/ndarray/array.jl
+++ b/julia/src/ndarray/array.jl
@@ -28,13 +28,14 @@ Base.similar(x::NDArray{T,N}; writable = x.writable, ctx = context(x)) where {T,
   NDArray{T,N}(undef, size(x)...; writable = writable, ctx = ctx)
 
 """
-    zeros([DType], dims, [ctx::Context = cpu()])
+    zeros([DType], dims, ctx::Context = current_context())
     zeros([DType], dims...)
     zeros(x::NDArray)
 
 Create zero-ed `NDArray` with specific shape and type.
 """
-function zeros(::Type{T}, dims::NTuple{N,Int}, ctx::Context = cpu()) where {N,T<:DType}
+function zeros(::Type{T}, dims::NTuple{N,Int},
+               ctx::Context = current_context()) where {N,T<:DType}
   x = NDArray{T}(undef, dims..., ctx = ctx)
   x[:] = zero(T)
   x
@@ -42,7 +43,7 @@ end
 
 zeros(::Type{T}, dims::Int...) where {T<:DType} = zeros(T, dims)
 
-zeros(dims::NTuple{N,Int}, ctx::Context = cpu()) where N =
+zeros(dims::NTuple{N,Int}, ctx::Context = current_context()) where N =
   zeros(MX_float, dims, ctx)
 zeros(dims::Int...) = zeros(dims)
 
@@ -50,13 +51,14 @@ zeros(x::NDArray)::typeof(x)      = zeros_like(x)
 Base.zeros(x::NDArray)::typeof(x) = zeros_like(x)
 
 """
-    ones([DType], dims, [ctx::Context = cpu()])
+    ones([DType], dims, ctx::Context = current_context())
     ones([DType], dims...)
     ones(x::NDArray)
 
 Create an `NDArray` with specific shape & type, and initialize with 1.
 """
-function ones(::Type{T}, dims::NTuple{N,Int}, ctx::Context = cpu()) where {N,T<:DType}
+function ones(::Type{T}, dims::NTuple{N,Int},
+              ctx::Context = current_context()) where {N,T<:DType}
   arr = NDArray{T}(undef, dims..., ctx = ctx)
   arr[:] = one(T)
   arr
@@ -64,7 +66,7 @@ end
 
 ones(::Type{T}, dims::Int...) where T<:DType = ones(T, dims)
 
-ones(dims::NTuple{N,Int}, ctx::Context = cpu()) where N =
+ones(dims::NTuple{N,Int}, ctx::Context = current_context()) where N =
   ones(MX_float, dims, ctx)
 ones(dims::Int...) = ones(dims)
 
@@ -458,12 +460,12 @@ function Base.fill!(arr::NDArray, x)
 end
 
 """
-    fill(x, dims, ctx=cpu())
+    fill(x, dims, ctx = current_context())
     fill(x, dims...)
 
 Create an `NDArray` filled with the value `x`, like `Base.fill`.
 """
-function fill(x::T, dims::NTuple{N,Integer}, ctx::Context = cpu()) where {T,N}
+function fill(x::T, dims::NTuple{N,Integer}, ctx::Context = current_context()) where {T,N}
   arr = NDArray{T}(undef, dims, ctx = ctx)
   arr[:] = x
   arr
diff --git a/julia/src/ndarray/type.jl b/julia/src/ndarray/type.jl
index 8d90d63..e24c8929 100644
--- a/julia/src/ndarray/type.jl
+++ b/julia/src/ndarray/type.jl
@@ -116,7 +116,7 @@ end
 
 # UndefInitializer constructors
 NDArray{T,N}(::UndefInitializer, dims::NTuple{N,Integer};
-             writable = true, ctx::Context = cpu()) where {T,N} =
+             writable = true, ctx::Context = current_context()) where {T,N} =
   NDArray{T,N}(_ndarray_alloc(T, dims, ctx, false), writable)
 NDArray{T,N}(::UndefInitializer, dims::Vararg{Integer,N}; kw...) where {T,N} =
   NDArray{T,N}(undef, dims; kw...)