Posted to commits@singa.apache.org by wa...@apache.org on 2019/12/24 03:54:18 UTC
[singa] branch master updated: SINGA-487 Update comments based on Google Python Style Guide
This is an automated email from the ASF dual-hosted git repository.
wangwei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/singa.git
The following commit(s) were added to refs/heads/master by this push:
new a39f377 SINGA-487 Update comments based on Google Python Style Guide
new 68c803e Merge pull request #568 from chrishkchris/dist_op
a39f377 is described below
commit a39f3777cb92060181c9c21534aac35b99886051
Author: Chris Yeung <ch...@yahoo.com.hk>
AuthorDate: Thu Dec 19 16:22:40 2019 +0800
SINGA-487 Update comments based on Google Python Style Guide
---
python/singa/opt.py | 193 +++++++++++++++++++++++++++++++++++++++++++---------
1 file changed, 161 insertions(+), 32 deletions(-)
diff --git a/python/singa/opt.py b/python/singa/opt.py
index 5df9a9f..d064739 100755
--- a/python/singa/opt.py
+++ b/python/singa/opt.py
@@ -24,7 +24,7 @@ from . import singa_wrap as singa
class Optimizer(object):
- r"""Base optimizer.
+ """Base optimizer.
Args:
config (Dict): specify the default values of configurable variables.
@@ -37,7 +37,7 @@ class Optimizer(object):
self.param2state = {}
def update(self, param, grad):
- r"""Update the param values with given gradients.
+ """Update the param values with given gradients.
Args:
param(Tensor): param values to be updated in-place
@@ -47,7 +47,7 @@ class Optimizer(object):
pass
def step(self):
- r"""To increment the step counter"""
+ """To increment the step counter"""
self.iter += 1
def register(self, param_group, config):
@@ -64,7 +64,7 @@ class Optimizer(object):
class SGD(Optimizer):
- r"""Implements stochastic gradient descent (optionally with momentum).
+ """Implements stochastic gradient descent (optionally with momentum).
Nesterov momentum is based on the formula from
`On the importance of initialization and momentum in deep learning`__.
@@ -76,7 +76,7 @@ class SGD(Optimizer):
dampening(float, optional): dampening for momentum(default: 0)
nesterov(bool, optional): enables Nesterov momentum(default: False)
- Example:
+ Typical usage example:
>>> from singa import opt
>>> optimizer = opt.SGD(lr=0.1, momentum=0.9)
>>> optimizer.update()
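A slightly fuller sketch of the same API, with the tensors `w` and `g` assumed for illustration (they are not part of this diff):
>>> from singa import tensor, opt
>>> optimizer = opt.SGD(lr=0.1, momentum=0.9)
>>> w = tensor.Tensor((3, 4))
>>> w.gaussian(0.0, 0.1)       # a parameter tensor
>>> g = tensor.Tensor((3, 4))
>>> g.set_value(0.01)          # its gradient
>>> optimizer.update(w, g)     # updates w in place
>>> optimizer.step()           # increment the step counter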
@@ -124,7 +124,7 @@ class SGD(Optimizer):
def update(self, param, grad):
"""Performs a single optimization step.
- Arguments:
+ Args:
param(Tensor): param values to be updated in-place
grad(Tensor): param gradients; the values may be updated
in this function; they should not be used afterwards
@@ -161,25 +161,51 @@ class SGD(Optimizer):
singa.Axpy(-group['lr'], grad.data, param.data)
def backward_and_update(self, loss):
+ """Performs backward propagation from the loss and parameter update.
+
+ From the loss, it performs backward propagation to get the gradients
+ and do the parameter update.
+
+ Args:
+ loss(Tensor): loss is the objective function of the deep learning model
+ optimization, e.g. for a classification problem it can be the output of the
+ softmax_cross_entropy function.
+ """
for p, g in autograd.backward(loss):
self.update(p, g)
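For context, this helper is typically called once per mini-batch; a hedged sketch, where `model`, `x` and `y` are assumed objects rather than code from this commit:
>>> from singa import autograd
>>> out = model(x)                                  # forward pass of an assumed autograd model
>>> loss = autograd.softmax_cross_entropy(out, y)   # the objective mentioned in the docstring
>>> optimizer.backward_and_update(loss)             # backward pass plus parameter update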
class DistOpt(object):
+ """The class is designed to wrap an optimizer to do distributed training.
- def __init__(self, opt=SGD(), nccl_id=None, gpu_num=None, gpu_per_node=None, buffSize=4194304):
- # The class is designed to wrap an optimizer to do disttributed training.
- # opt: The optimizer to be wrapped. nDev: number of devices(GPUs) a
- # process will control/use.
- # nccl_id: an nccl id holder object for a unique communication id
- # gpu_num: the GPU id in a single node
- # gpu_per_node: the number of GPUs in a single node
- # buffSize: the buffSize used in nccl communicator, default is 16 MB
-
- # world_size: total number of processes.
- # rank_in_local: local rank of a process on the current node.
- # rank_in_global: global rank of a process
+ This class is used to wrap an optimizer object to perform distributed training based
+ on multiprocessing. Each process has an individual rank, which gives information on
+ which GPU the individual process is using. The training data is partitioned, so that
+ each process can evaluate the sub-gradient based on its partition of the training data.
+ Once the sub-gradient is calculated in each process, the overall stochastic gradient
+ is obtained by all-reducing the sub-gradients evaluated by all processes. The all-reduce
+ operation is supported by the NVIDIA Collective Communications Library (NCCL).
+
+ Args:
+ opt(Optimizer): The optimizer to be wrapped.
+ nccl_id(NcclIdHolder): an nccl id holder object for a unique communication id
+ gpu_num(int): the GPU id in a single node
+ gpu_per_node(int): the number of GPUs in a single node
+ buffSize(int): the buffer size, in number of elements, used in the nccl communicator
+
+ Attributes:
+ world_size(int): total number of processes
+ rank_in_local(int): local rank of a process on the current node
+ rank_in_global(int): global rank of a process
+
+ Typical usage example:
+ >>> from singa import opt
+ >>> sgd = opt.SGD(lr=0.1, momentum=0.9)
+ >>> optimizer = opt.DistOpt(sgd)
+
+ """
+ def __init__(self, opt=SGD(), nccl_id=None, gpu_num=None, gpu_per_node=None, buffSize=4194304):
self.opt = opt
if nccl_id is None:
# constructor for applications using MPI
@@ -193,30 +219,78 @@ class DistOpt(object):
self.rank_in_global = self.communicator.MPIRankInGlobal
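Building on the construction example in the class docstring, the rank attributes are typically used to pin each process to a GPU and to choose its data partition. This is a hedged sketch only; the device call and the slicing over an assumed `num_samples` are illustrative, not code from this commit:
>>> from singa import device, opt
>>> sgd = opt.SGD(lr=0.1, momentum=0.9)
>>> optimizer = opt.DistOpt(sgd)
>>> dev = device.create_cuda_gpu_on(optimizer.rank_in_local)   # one GPU per local rank
>>> # each of the optimizer.world_size processes trains on its own shard of the data
>>> my_indices = range(optimizer.rank_in_global, num_samples, optimizer.world_size)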
def update(self, param, grad):
+ """Performs a single optimization step.
+
+ Args:
+ param(Tensor): param values to be updated
+ grad(Tensor): param gradients
+ """
grad /= self.world_size
self.opt.update(param, grad)
def all_reduce(self, tensor):
+ """Performs all reduce of a tensor for distributed training.
+
+ Args:
+ tensor(Tensor): a tensor to be all-reduced
+ """
self.communicator.synch(tensor)
def fused_all_reduce(self, tensor):
+ """Performs all reduce of the tensors after fusing them in a buffer.
+
+ Args:
+ tensor(List of Tensors): a list of tensors to be all-reduced
+ """
tensor = singa.VecTensor(tensor)
self.communicator.fusedSynch(tensor)
def all_reduce_half(self, tensor):
+ """Performs all reduce of a tensor after converting to FP16.
+
+ Args:
+ tensor(Tensor): a tensor to be all-reduced
+ """
self.communicator.synchHalf(tensor)
def fused_all_reduce_half(self, tensor):
+ """Performs all reduce of the tensors after fusing and converting them to FP16.
+
+ Args:
+ tensor(List of Tensors): a list of tensors to be all-reduced
+ """
tensor = singa.VecTensor(tensor)
self.communicator.fusedSynchHalf(tensor)
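These four collectives differ only in whether they fuse tensors and whether they convert to FP16. A minimal hand-rolled use of the plain variant might look like the following; that the communicator takes the underlying `.data` of an assumed gradient tensor `g` is an assumption here, not something shown in this hunk:
>>> optimizer.all_reduce(g.data)   # sum g across all processes
>>> optimizer.wait()               # block until the communicator's CUDA streams finish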
def sparsification(self, tensor, accumulation, spars, topK):
+ """Performs all reduce of a tensor after sparsification.
+
+ Args:
+ tensor(Tensor): a tensor to be all-reduced
+ accumulation(Tensor): local gradient accumulation
+ spars(float): a parameter to control sparsity as defined below
+ topK(bool): When topK is False, it sparsifies the gradient elements with absolute
+ value >= spars. When topK is True, it sparsifies a fraction of the total gradient
+ elements equal to spars, e.g. when spars = 0.01, it sparsifies 1% of the
+ total gradient elements
+ """
if accumulation is None:
self.communicator.sparsification(tensor, spars, topK)
else:
self.communicator.sparsification(tensor, accumulation, spars, topK)
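The two selection rules behind topK are easier to see in plain numpy; this is only a conceptual sketch of which gradient elements survive sparsification, not the communicator's implementation:
>>> import numpy as np
>>> g = np.random.randn(1000)                  # stand-in for a flattened gradient
>>> kept_by_threshold = np.abs(g) >= 0.05      # topK=False, spars=0.05: keep |g| >= spars
>>> k = int(0.01 * g.size)                     # topK=True, spars=0.01: keep the top 1%
>>> kept_by_topk = np.argsort(np.abs(g))[-k:]  # indices of the k largest magnitudes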
def fused_sparsification(self, tensor, accumulation, spars, topK):
+ """Performs all reduce of the tensors after fusing and sparsification.
+
+ Args:
+ tensor(List of Tensors): a list of tensors to be all-reduced
+ accumulation(Tensor): local gradient accumulation
+ spars(float): a parameter to control sparsity as defined below
+ topK(bool): When topK is False, it sparsifies the gradient elements with absolute
+ value >= spars. When topK is True, it sparsifies a fraction of the total gradient
+ elements equal to spars, e.g. when spars = 0.01, it sparsifies 1% of the
+ total gradient elements
+ """
tensor = singa.VecTensor(tensor)
if accumulation is None:
self.communicator.fusedSparsification(tensor, spars, topK)
@@ -224,11 +298,26 @@ class DistOpt(object):
self.communicator.fusedSparsification(tensor, accumulation, spars, topK)
def wait(self):
+ """Wait for the cuda streams used by the communicator to finish their operations."""
self.communicator.wait()
def backward_and_update(self, loss, threshold = 2097152):
- # backward propagation from the loss and parameter update
- # it applies tensor fusion which fuses all the tensor smaller than the threshold value
+ """Performs backward propagation from the loss and parameter update.
+
+ From the loss, it performs backward propagation to get the gradients and do the parameter
+ update. For gradient communication, it fuses all the tensors smaller than the threshold
+ value to reduce network latency.
+
+ Args:
+ loss(Tensor): loss is the objective function of the deep learning model
+ optimization, e.g. for a classification problem it can be the output of the
+ softmax_cross_entropy function.
+ threshold(int): threshold is a parameter to control performance in fusing
+ the tensors. Tensors smaller than the threshold value are accumulated and
+ fused before the all-reduce operation, while tensors larger than the threshold
+ value are reduced directly without fusion.
+ """
plist = []
acc = 0
glist = []
@@ -252,9 +341,27 @@ class DistOpt(object):
self.update(p, g)
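Most of the fusion logic sits in the part of the method elided between the two hunks above; the following is only an assumed shape of that rule, not the committed implementation:
>>> for p, g in autograd.backward(loss):
...     if g.size() > threshold:
...         self.all_reduce(g.data)    # large tensor: all-reduced on its own
...     else:
...         plist.append(p)            # small tensors: accumulated until the buffer
...         glist.append(g.data)       # is fused and all-reduced in one shot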
def backward_and_update_half(self, loss, threshold = 2097152, clipping = False, clip_Value = 100):
- # THIS IS A EXPERIMENTAL FUNCTION FOR RESEARCH PURPOSE:
- # It converts the gradients to 16 bits half precision format before allreduce
- # To assist training, this functions provide an option to perform gradient clipping
+ """Performs backward propagation and parameter update, with FP16 precision communication.
+
+ THIS IS AN EXPERIMENTAL FUNCTION FOR RESEARCH PURPOSES:
+ From the loss, it performs backward propagation to get the gradients and do the parameter
+ update. For gradient communication, it fuses all the tensors smaller than the threshold value
+ to reduce network latency, as well as converting them to FP16 half precision format before
+ sending them out. To assist training, this function provides an option to perform gradient
+ clipping.
+
+ Args:
+ loss(Tensor): loss is the objective function of the deep learning model
+ optimization, e.g. for a classification problem it can be the output of the
+ softmax_cross_entropy function.
+ threshold(int): threshold is a parameter to control performance in fusing
+ the tensors. Tensors smaller than the threshold value are accumulated and
+ fused before the all-reduce operation, while tensors larger than the threshold
+ value are reduced directly without fusion.
+ clipping(bool): a boolean flag to choose whether to clip the gradient value
+ clip_Value(float): the clip value to be used when clipping is True
+ """
plist = []
acc = 0
glist = []
@@ -280,10 +387,29 @@ class DistOpt(object):
self.update(p, g)
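A call with gradient clipping enabled, keeping the argument names of the signature above (the `loss` tensor is assumed to come from an earlier forward pass):
>>> optimizer.backward_and_update_half(loss, threshold=2097152, clipping=True, clip_Value=100)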
def backward_and_partial_update(self, loss, threshold = 2097152):
- # THIS IS A EXPERIMENTAL FUNCTION FOR RESEARCH PURPOSE:
- # It performs asychronous training where one parameter partition is all-reduced per iteration
- # The size of the parameter partition depends on the threshold value
- # self.partial is the counter to determine which partition to perform all-reduce
+ """Performs backward propagation from the loss and parameter update using asychronous training.
+
+ THIS IS A EXPERIMENTAL FUNCTION FOR RESEARCH PURPOSE:
+ From the loss, it performs backward propagation to get the gradients and do the parameter
+ update. It fuses the tensors smaller than the threshold value to reduce network latency,
+ as well as performing asychronous training where one parameter partition is all-reduced
+ per iteration. The size of the parameter partition depends on the threshold value.
+
+ Args:
+ loss(Tensor): loss is the objective function of the deep learning model
+ optimization, e.g. for classification problem it can be the output of the
+ softmax_cross_entropy function.
+ threshold(int): threshold is a parameter to control performance in fusing
+ the tensors. For the tensors of sizes smaller than threshold, they are to
+ be accumulated and fused before the all reduce operation. For the tensors
+ of its size larger than the threshold value, they are to be reduced directly
+ without fusion.
+
+ Attributes:
+ self.partial(int): A counter to determine which partition to perform all-reduce.
+ This counter resets to zero automatlly after an update cycle of the full parameter
+ set.
+ """
if not hasattr(self, "partial"):
self.partial = 0
self.partial += 1
@@ -330,12 +456,15 @@ class DistOpt(object):
self.partial = 0
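The partition counter documented in the Attributes section cycles through the parameters one threshold-sized slice per call; purely as a standalone sketch of that round-robin behaviour, with the partition count of 4 assumed:
>>> num_partitions = 4
>>> partial = 0
>>> for step in range(10):
...     partial += 1                   # only slice number `partial` is all-reduced this call
...     if partial == num_partitions:
...         partial = 0                # full parameter set refreshed; start over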
def backward_and_spars_update(self, loss, threshold = 2097152, spars = 0.05, topK = False, corr = True):
- r"""THIS IS A EXPERIMENTAL FUNCTION FOR RESEARCH PURPOSE:
- Performs backward propagation from the loss and parameter update with sparsification.
- It fuses the tensors with size smaller than the threshold value to reduce network latency, as well
- as using sparsification scheme to transfer only gradient elements which are significant.
+ """ Performs backward propagation from the loss and parameter update with sparsification.
- Arguments:
+ THIS IS AN EXPERIMENTAL FUNCTION FOR RESEARCH PURPOSES:
+ From the loss, it performs backward propagation to get the gradients and do the parameter
+ update. It fuses the tensors with size smaller than the threshold value to reduce network
+ latency, as well as using sparsification schemes to transfer only the gradient elements which
+ are significant.
+
+ Args:
loss(Tensor): loss is the objective function of the deep learning model
optimization, e.g. for a classification problem it can be the output of the
softmax_cross_entropy function.