You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2020/07/26 05:24:57 UTC

[GitHub] [incubator-mxnet] szha commented on a change in pull request #17298: [MXNET-1438] Adding SDML loss function

szha commented on a change in pull request #17298:
URL: https://github.com/apache/incubator-mxnet/pull/17298#discussion_r460480501



##########
File path: python/mxnet/gluon/loss.py
##########
@@ -930,3 +930,118 @@ def _cosine_similarity(self, F, x, y, axis=-1):
         else:
             eps_arr = F.full((1, 1), 1e-12)
         return (x_dot_y / F.broadcast_maximum(x_norm * y_norm, eps_arr))
+
+
+class SDMLLoss(Loss):
+    r"""Calculates Batchwise Smoothed Deep Metric Learning (SDML) Loss given two input tensors and a smoothing weight
+    SDM Loss learns similarity between paired samples by using unpaired samples in the minibatch
+    as potential negative examples.
+
+    The loss is described in greater detail in
+    "Large Scale Question Paraphrase Retrieval with Smoothed Deep Metric Learning."
+    - by Bonadiman, Daniele, Anjishnu Kumar, and Arpit Mittal.  arXiv preprint arXiv:1905.12786 (2019).
+    URL: https://arxiv.org/pdf/1905.12786.pdf
+
+    According to the authors, this loss formulation achieves comparable or higher accuracy to
+    Triplet Loss but converges much faster.
+    The loss assumes that the items in both tensors in each minibatch
+    are aligned such that x1[0] corresponds to x2[0] and all other datapoints in the minibatch are unrelated.
+    `x1` and  `x2` are minibatches of vectors.
+
+    Parameters
+    ----------
+    smoothing_parameter : float
+        Probability mass to be distributed over the minibatch. Must be < 1.0.
+    weight : float or None
+        Global scalar weight for loss.
+    batch_axis : int, default 0
+        The axis that represents mini-batch.
+
+    Inputs:
+        - **x1**: Minibatch of data points with shape (batch_size, vector_dim)
+        - **x2**: Minibatch of data points with shape (batch_size, vector_dim)
+          Each item in x2 is a positive sample for the same index in x1.
+          That is, x1[0] and x2[0] form a positive pair, x1[1] and x2[1] form a positive pair - and so on.
+          All data points in different rows should be decorrelated
+
+    Outputs:
+        - **loss**: loss tensor with shape (batch_size,).
+    """
+
+    def __init__(self, smoothing_parameter=0.3, weight=1., batch_axis=0, **kwargs):
+        """Configure the SDML loss.
+
+        smoothing_parameter : float
+            Probability mass moved off the diagonal of the target matrix
+            (see `_compute_labels`). NOTE(review): the class docstring says it
+            must be < 1.0, but this constructor does not check that.
+        weight, batch_axis, **kwargs
+            Forwarded unchanged to the `Loss` base class.
+        """
+        super(SDMLLoss, self).__init__(weight, batch_axis, **kwargs)
+        # from_logits=True: `_loss` feeds this loss log_softmax output, i.e.
+        # inputs that are already log-probabilities.
+        self.kl_loss = KLDivLoss(from_logits=True)
+        self.smoothing_parameter = smoothing_parameter # Smoothing probability mass
+
+    def _compute_distances(self, x1, x2):
+        """
+        Compute the pairwise *squared* Euclidean distance between every row of
+        `x1` and every row of `x2`; no square root is taken.
+
+        Both inputs must be 2-D with identical shape [batch_size, dim]. The
+        result has shape [batch_size, batch_size], with entry (i, j) equal to
+        sum_k (x1[i, k] - x2[j, k]) ** 2.
+
+        NOTE(review): this reads `x1.shape` and broadcasts to an explicit
+        shape, which presumably requires concrete (NDArray) inputs; symbolic
+        inputs likely need a shape-free formulation. The shape check is an
+        `assert`, so it is skipped under `python -O`.
+        """
+
+        # extracting sizes expecting [batch_size, dim]
+        assert x1.shape == x2.shape
+        batch_size, dim = x1.shape
+        # expanding both tensors from [batch_size, dim] to [batch_size, batch_size, dim]
+        # so that x1_[i, j] == x1[i] and x2_[i, j] == x2[j]
+        x1_ = x1.expand_dims(1).broadcast_to([batch_size, batch_size, dim])
+        x2_ = x2.expand_dims(0).broadcast_to([batch_size, batch_size, dim])
+        # pointwise squared differences
+        squared_diffs = (x1_ - x2_)**2
+        # sum over the feature axis -> squared Euclidean distance for each (i, j) pair
+        return squared_diffs.sum(axis=2)
+
+
+    def _compute_labels(self, F, batch_size):
+        """
+        Build the smoothed target distribution used by the loss.
+
+        Starts from the [batch_size x batch_size] identity matrix, where row i
+        marks column i as its positive, e.g. for batch_size=2:
+            [[1, 0],
+             [0, 1]]
+
+        Label smoothing is then applied: the diagonal keeps
+        1 - smoothing_parameter, and the remaining smoothing_parameter mass is
+        spread evenly over the other batch_size - 1 columns, so each row still
+        sums to 1. With smoothing_parameter=0.1 and batch_size=2:
+            [[0.9, 0.1],
+             [0.1, 0.9]]
+
+        See Pereyra, Gabriel, et al. "Regularizing neural networks by penalizing
+        confident output distributions." arXiv preprint arXiv:1701.06548 (2017).
+
+        NOTE(review): for batch_size == 1 the divisor (batch_size - 1) is 0,
+        so the off-diagonal term divides by zero (presumably producing NaN
+        labels); the loss needs at least two pairs per batch.
+        """
+
+        # TODO: replace with mx.nd.eye(batch_size) with mxnet 1.2
+        # one_hot(arange(n), n) builds an n x n identity matrix.
+        gold = F.one_hot(F.arange(batch_size), batch_size)
+        # diagonal -> 1 - s, off-diagonal -> s / (batch_size - 1)
+        labels = gold * (1 - self.smoothing_parameter) + (1 - gold) * self.smoothing_parameter / (batch_size - 1)
+        return labels
+
+
+    def _loss(self, F, x1, x2):
+        """
+        Compute the SDML loss for two aligned minibatches.
+
+        Each row of `x1` is scored against every row of `x2` by its negated
+        pairwise squared distance. A row-wise log_softmax turns those scores
+        into log-probabilities, which are compared via KL divergence with the
+        smoothed identity targets from `_compute_labels`.
+
+        This assumes the two batches are aligned, so the most similar vector
+        in `x2` for x1[i] should be x2[i]:
+
+        Batch1                                Batch2
+
+        President of France                   French President
+        President of US                       American President
+
+        Given "President of France" from batch 1, the model learns to pick
+        "French President" out of all the vectors in batch 2.
+
+        NOTE(review): `x1.shape[0]` and `labels.as_in_context(distances.context)`
+        presumably need concrete NDArray inputs; this is what the review
+        comment on this hunk flags as not working in symbolic mode.
+        """
+        batch_size = x1.shape[0]
+        labels = self._compute_labels(F, batch_size)
+        distances = self._compute_distances(x1, x2)
+        # softmax over axis=1: for each x1[i], a distribution over all rows of x2
+        log_probabilities = F.log_softmax(-distances, axis=1)
+        # multiply for the number of labels to obtain the correct loss (gluon kl_loss averages instead of sum)
+        # (distances is [batch_size, batch_size], so each row has batch_size labels)
+        return self.kl_loss(log_probabilities, labels.as_in_context(distances.context)) * batch_size

Review comment:
       This doesn't work in symbolic (hybridized) mode. On 1.x it will need to be fixed; on 2.0 it will be switched to deferred compute mode.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org