Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/10/08 18:35:00 UTC

[GitHub] roywei commented on a change in pull request #12750: [MXNET-1030] Cosine Embedding Loss

roywei commented on a change in pull request #12750: [MXNET-1030] Cosine Embedding Loss
URL: https://github.com/apache/incubator-mxnet/pull/12750#discussion_r223445427
 
 

 ##########
 File path: python/mxnet/gluon/loss.py
 ##########
 @@ -706,3 +706,77 @@ def hybrid_forward(self, F, pred, positive, negative):
                      axis=self._batch_axis, exclude=True)
         loss = F.relu(loss + self._margin)
         return _apply_weighting(F, loss, self._weight, None)
+
+class CosineEmbeddingLoss(Loss):
+    r"""For a target label 1 or -1, vectors target and pred, the function computes the cosine distance
+    between the vectors. This can be interpretted as how similar/dissimilar two input vectors are.
+    .. math::
+
+        Cosine\_loss =
+        \begin{cases}
+            1 - cos\_sim(pred, target) & \text{if } label = 1 \\
+            cos\_sim(pred, target)     & \text{if } label = -1
+        \end{cases}
+
+    where the cosine similarity of the two vectors is
+
+    .. math::
+
+        cos\_sim(pred, target) = \frac{pred \cdot target}{\lVert pred \rVert \, \lVert target \rVert}
+
+    For example, for three-dimensional vectors :math:`pred = (p_1, p_2, p_3)` and
+    :math:`target = (t_1, t_2, t_3)`:
+
+    .. math::
+
+        cos\_sim(pred, target) = \frac{p_1 t_1 + p_2 t_2 + p_3 t_3}
+                                      {\sqrt{p_1^2 + p_2^2 + p_3^2} \, \sqrt{t_1^2 + t_2^2 + t_3^2}}
+
+    `pred` and `target` can have arbitrary shapes as long as they have the same number of elements.
+
+    Parameters
+    ----------
+    weight : float or None
+        Global scalar weight for loss.
+    batch_axis : int, default 0
+        The axis that represents mini-batch.
+    margin : float, default 0
+        The margin below which the cosine similarity of a dissimilar
+        (label -1) pair incurs no loss.
+
+
+    Inputs:
+    -------
+        - **pred**: prediction tensor with arbitrary shape.
+        - **target**: target tensor with the same shape as pred.
+        - **label**: A 1-D tensor of shape (batch_size,) indicating for each
+          pred/target pair whether the pair is similar (label 1) or
+          dissimilar (label -1).
+        - **sample_weight**: element-wise weighting tensor. Must be broadcastable
+          to the same shape as pred. For example, if pred has shape (64, 10)
+          and you want to weigh each sample in the batch separately,
+          sample_weight should have shape (64, 1).
+
+    Outputs:
+    --------
+        - **loss**: The loss tensor with shape (batch_size,).
+    """
+    def __init__(self, weight=None, batch_axis=0, margin=0, **kwargs):
+        super(CosineEmbeddingLoss, self).__init__(weight, batch_axis, **kwargs)
+        self._margin = margin
+
+    def hybrid_forward(self, F, pred, target, label, sample_weight=None):
+        pred = _reshape_like(F, pred, target)
+        cos_sim = self.cosine_similarity(F, pred, target)
+        y_1 = label == 1
+        y_minus_1 = label == -1
+        # Pairs labelled 1 are penalized for being dissimilar; pairs
+        # labelled -1 are penalized for being more similar than the margin.
+        cos_sim_a = (1 - cos_sim) * y_1
+        # relu(x) = max(0, x); it is supported by both the ndarray and
+        # symbol APIs, so the block stays hybridizable.
+        cos_sim_b = F.relu(y_minus_1 * (cos_sim - self._margin))
+        loss = cos_sim_a + cos_sim_b
+        return _apply_weighting(F, loss, self._weight, sample_weight)
+
+    def cosine_similarity(self, F, F1, F2, axis=-1):
 
 Review comment:
   nit: rename `F1` and `F2` to `x`, `y` or `x1`, `x2` to avoid confusion with `F`: `F` is used for hybridization, while `F1` and `F2` here are vectors.
   Also add a small comment explaining that you are computing the cosine similarity here.
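   Something like the following (just a sketch -- I am guessing at the body from the docstring's formula, so adapt it to what you actually compute):

   ```python
   def cosine_similarity(self, F, x, y, axis=-1):
       # Cosine similarity along `axis`:
       #   cos_sim(x, y) = sum(x * y) / (||x|| * ||y||)
       x_dot_y = F.sum(x * y, axis=axis)
       x_norm = F.norm(x, axis=axis)
       y_norm = F.norm(y, axis=axis)
       # The small epsilon keeps the division stable when a norm is zero.
       return x_dot_y / (x_norm * y_norm + 1e-12)
   ```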

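   For anyone following along, here is a quick usage sketch checking the loss against the docstring's piecewise definition (assumes this PR is applied and the signature in this diff; not part of the PR):

   ```python
   import mxnet as mx
   from mxnet.gluon.loss import CosineEmbeddingLoss  # added by this PR

   loss_fn = CosineEmbeddingLoss(margin=0.0)

   # Two pairs of 3-d vectors with one label per pair:
   # label 1 penalizes dissimilarity, label -1 penalizes similarity.
   pred = mx.nd.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
   target = mx.nd.array([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
   label = mx.nd.array([1, -1])

   print(loss_fn(pred, target, label))
   # First pair: identical vectors, label 1   -> 1 - cos_sim = 1 - 1 = 0
   # Second pair: orthogonal vectors, label -1 -> max(0, cos_sim - margin) = 0
   ```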
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services