Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2020/01/19 20:38:36 UTC

[GitHub] [incubator-mxnet] haojin2 commented on a change in pull request #17360: add random.multivariate_normal, fix empty_like dtype problem, fix gat…

haojin2 commented on a change in pull request #17360: add random.multivariate_normal, fix empty_like dtype problem, fix gat…
URL: https://github.com/apache/incubator-mxnet/pull/17360#discussion_r368320815
 
 

 ##########
 File path: python/mxnet/numpy/random.py
 ##########
 @@ -242,6 +242,82 @@ def multinomial(n, pvals, size=None, **kwargs):
     return _mx_nd_np.random.multinomial(n, pvals, size, **kwargs)
 
 
+# pylint: disable=unused-argument
+def multivariate_normal(mean, cov, size=None, check_valid=None, tol=None):
+    """
+    multivariate_normal(mean, cov, size=None, check_valid=None, tol=None)
+
+    Draw random samples from a multivariate normal distribution.
+
+    The multivariate normal, multinormal or Gaussian distribution is a
+    generalization of the one-dimensional normal distribution to higher
+    dimensions.  Such a distribution is specified by its mean and
+    covariance matrix.  These parameters are analogous to the mean
+    (average or "center") and variance (standard deviation, or "width,"
+    squared) of the one-dimensional normal distribution.
+
+    This operator differs from the one in official NumPy. The official NumPy
+    operator only accepts a 1-D ndarray as `mean` and a 2-D ndarray as `cov`,
+    whereas the operator in DeepNumPy supports batch operation and auto-broadcasting.
+
+    Both `mean` and `cov` may have any number of leading dimensions, which correspond
+    to a batch shape. The two batch shapes need not be identical; they only need
+    to be broadcast-compatible.
+
+    Parameters
+    ----------
+    mean : K-D ndarray, of shape (..., N)
+        Mean of the N-dimensional distribution.
+    cov : (K+1)-D ndarray, of shape (..., N, N)
+        Covariance matrix of the distribution. Each N-by-N matrix in the last two
+        dimensions must be symmetric and positive-semidefinite for proper sampling.
+    size : int or tuple of ints, optional
+        Given a shape of, for example, ``(m,n,k)``,
+        ``m*n*k`` identically distributed batches of samples are
+        generated, and packed in an `m`-by-`n`-by-`k` arrangement.
+        If no shape is specified, a single batch of (`N`-D) samples is returned.
+    check_valid : { 'warn', 'raise', 'ignore' }, optional
+        Behavior when the covariance matrix is not positive semidefinite.
+        (Not supported)
+    tol : float, optional
+        Tolerance when checking the singular values in the covariance matrix.
+        `cov` is cast to double before the check.
+        (Not supported)
+
+    Returns
+    -------
+    out : ndarray
+        ``mean.shape`` and ``cov.shape[:-1]`` must be broadcast-compatible.
+        If the parameter `size` is not provided, the output shape is the broadcast
+        of ``mean.shape`` and ``cov.shape[:-1]``. Otherwise, the output shape is
+        ``size`` followed by that broadcast shape.
+
+    Examples
+    --------
+    >>> mean = np.array([1, 2])
+    >>> cov = np.array([[1, 0], [0, 1]])
+    >>> x = np.random.multivariate_normal(mean, cov, (3, 3))
+    >>> x.shape
+    (3, 3, 2)
+
+    Each component of ``x[0,0,:] - mean`` is drawn from a standard normal
+    distribution (the covariance is the identity), so each comparison below
+    is True with probability of roughly 0.73, the standard normal CDF at 0.6:
+
+    >>> list((x[0,0,:] - mean) < 0.6)
+    [True, True] # random
+
+    Autobroadcasting is performed when the batch shapes of `mean` and `cov`
+    differ but are compatible:
+
+    >>> mean = np.zeros((3,2)) # shape (3, 2)
+    >>> cov = np.array([[1, 0], [0, 100]]) # shape (2, 2)
+    >>> x = np.random.multivariate_normal(mean, cov)
+    >>> x
+    array([[-1.6115597 , -8.726251  ],
+           [ 2.2425299 ,  2.8104177 ],
+           [ 0.36229908, -8.386591  ]]) # random
+    """
+    return _mx_nd_np.random.multivariate_normal(mean, cov, size=size,\
 
 Review comment:
   no need for backslash here, and fix alignment
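The reviewer's request can be illustrated with a minimal sketch. Inside a call's parentheses Python joins lines implicitly, so a trailing backslash is unnecessary, and PEP 8's "aligned" style lines continuation lines up with the first argument after the opening parenthesis. `_sample_impl` is a hypothetical stand-in for `_mx_nd_np.random.multivariate_normal`, and forwarding `check_valid` and `tol` is an assumption, because the rest of the original call is cut off in the diff.

```python
def _sample_impl(mean, cov, size=None, check_valid=None, tol=None):
    """Hypothetical stand-in for the backend call; it just echoes its arguments."""
    return (mean, cov, size, check_valid, tol)


def multivariate_normal(mean, cov, size=None, check_valid=None, tol=None):
    # No backslash: the open parenthesis lets the call span two lines, and the
    # second line is aligned with `mean`, the first argument after "(".
    return _sample_impl(mean, cov, size=size,
                        check_valid=check_valid, tol=tol)
```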

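To make the batch-broadcasting and output-shape rules in the docstring concrete, here is a reference sketch in plain NumPy. It is not MXNet's implementation: `batched_multivariate_normal` is a hypothetical name, and the sketch assumes NumPy 1.20 or later (for `np.broadcast_shapes`) and positive-definite covariances, since it samples through a Cholesky factor rather than handling merely semidefinite matrices.

```python
import numpy as np


def batched_multivariate_normal(mean, cov, size=None, rng=None):
    """Hypothetical reference sketch (plain NumPy, not MXNet's implementation).

    mean has shape (..., N) and cov has shape (..., N, N); their batch
    shapes only need to be broadcast-compatible.
    """
    rng = np.random.default_rng() if rng is None else rng
    mean = np.asarray(mean, dtype=float)
    cov = np.asarray(cov, dtype=float)

    # Output shape rule: `size` followed by the broadcast of
    # mean.shape and cov.shape[:-1].
    batch_shape = np.broadcast_shapes(mean.shape, cov.shape[:-1])
    if size is None:
        size = ()
    elif np.ndim(size) == 0:
        size = (int(size),)
    out_shape = tuple(size) + batch_shape

    # Lower-triangular L with L @ L.T == cov for every matrix in the stack.
    # Cholesky requires positive-definite matrices, so singular (only
    # semidefinite) covariances can raise LinAlgError here.
    chol = np.linalg.cholesky(cov)

    # x = mean + L @ z with z ~ N(0, I); matmul broadcasts the batch dimensions.
    z = rng.standard_normal(out_shape)
    return mean + (chol @ z[..., None])[..., 0]
```

For example, `batched_multivariate_normal(np.zeros((3, 2)), [[1, 0], [0, 100]])` has shape `(3, 2)`, matching the broadcasting example in the docstring.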
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services