You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/09/02 15:48:32 UTC

[GitHub] [incubator-tvm] hanke580 commented on a change in pull request #6370: [TOPI] Add einsum operator

hanke580 commented on a change in pull request #6370:
URL: https://github.com/apache/incubator-tvm/pull/6370#discussion_r482177048



##########
File path: include/tvm/topi/transform.h
##########
@@ -1281,6 +1285,832 @@ inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, Array<PrimExp
   return compute(output_shape, func, name, tag);
 }
 
+/*!
+ * \brief Build one stride expression per axis of a shape, walking from the last axis back.
+ *
+ * The stride of an axis is the product of the extents of all axes after it. The value
+ * is wrapped in if_then_else so that it is selected only when the extent of the axis is
+ * greater than 1, and 0 is selected otherwise.
+ *
+ * \param shape The shape to derive strides from. Each extent is passed to GetConstInt.
+ * \return An array holding one stride expression for each axis of shape.
+ */
+inline Array<PrimExpr> get_stride(const Array<PrimExpr> shape) {
+  const size_t ndim = shape.size();
+  Array<PrimExpr> stride(ndim, -1);
+  /* Product of the extents of the axes visited so far. */
+  int running = 1;
+  for (int axis = static_cast<int>(ndim) - 1; axis >= 0; --axis) {
+    const PrimExpr extent = shape[axis];
+    stride.Set(axis, if_then_else(extent > 1, running, 0));
+    running = running * GetConstInt(extent);
+  }
+  return stride;
+}
+
+/*!
+ * \brief Extend a shape to odim axes by appending extents of 1.
+ *
+ * \param shape The shape to extend.
+ * \param odim The number of axes of the result; checked to be at least shape.size().
+ * \return An array of odim extents: the extents of shape followed by 1s.
+ */
+inline Array<PrimExpr> pad(const Array<PrimExpr> shape, int odim) {
+  const int ndim = shape.size();
+  CHECK_GE(odim, ndim);
+  /* Start from all ones, then overwrite the leading entries with the real extents. */
+  Array<PrimExpr> padded(static_cast<size_t>(odim), 1);
+  int axis = 0;
+  while (axis < ndim) {
+    padded.Set(axis, shape[axis]);
+    ++axis;
+  }
+  return padded;
+}
+
+/*!
+ * \brief Parse the subscripts of one einsum operand into per-axis labels.
+ *
+ * Letters become axis labels, "..." marks the position of the broadcast axes and
+ * spaces are skipped; any other character fails a CHECK. On return op_labels holds,
+ * for each of the ndim axes, the label character, 0 for an axis covered by the
+ * ellipsis, or a negative offset pointing back at the first axis that carries the
+ * same label.
+ *
+ * \param subscripts The subscripts of this operand (read for length characters).
+ * \param length The number of characters of subscripts to read.
+ * \param ndim The number of axes of the operand.
+ * \param iop The index of the operand, used in error messages only.
+ * \param op_labels Output buffer with room for ndim entries.
+ * \param label_counts Counters indexed by label character; bumped once per label read.
+ * \param min_label Lowered to the smallest label read, if smaller than its current value.
+ * \param max_label Raised to the largest label read, if larger than its current value.
+ * \return Always 0; malformed input is reported through CHECK.
+ */
+inline int parse_operand_subscripts(const char *subscripts, int length,
+                                    int ndim, int iop, char *op_labels,
+                                    char *label_counts, int *min_label, int *max_label) {
+  /* Number of axis labels stored so far. */
+  int idim = 0;
+  /* Axis position of the ellipsis, or -1 while none has been seen. */
+  int ellipsis = -1;
+
+  /* First pass: classify every character of the subscripts. */
+  for (int i = 0; i < length; ++i) {
+    const int label = subscripts[i];
+
+    if (label > 0 && isalpha(label)) {
+      /* An axis label: there must still be an operand axis left to attach it to. */
+      CHECK(idim < ndim)
+        << "einstein sum subscripts string contains "
+        << "too many subscripts for operand "
+        << iop;
+
+      op_labels[idim++] = label;
+      if (*min_label > label) {
+        *min_label = label;
+      }
+      if (*max_label < label) {
+        *max_label = label;
+      }
+      ++label_counts[label];
+      continue;
+    }
+
+    if (label == '.') {
+      /*
+       * Only one ellipsis is allowed and it must be exactly "...".
+       * The condition itself steps i over the two following dots.
+       */
+      CHECK(!(ellipsis != -1 || i + 2 >= length
+              || subscripts[++i] != '.' || subscripts[++i] != '.'))
+        << "einstein sum subscripts string contains a "
+        << "'.' that is not part of an ellipsis ('...') "
+        << "in operand "
+        << iop;
+
+      ellipsis = idim;
+      continue;
+    }
+
+    /* Anything that is neither a letter nor part of an ellipsis must be a space. */
+    CHECK(label == ' ')
+      << "invalid subscript '" << static_cast<char>(label)
+      << "' in einstein sum "
+      << "subscripts string, subscripts must "
+      << "be letters";
+  }
+
+  if (ellipsis == -1) {
+    /* Without an ellipsis every axis needs its own label. */
+    CHECK(idim == ndim)
+      << "operand has more dimensions than subscripts "
+      << "given in einstein sum, but no '...' ellipsis "
+      << "provided to broadcast the extra dimensions.";
+  } else if (idim < ndim) {
+    /* Shift the labels written after the ellipsis to the tail, last label first. */
+    const int trailing = idim - ellipsis;
+    for (int k = 1; k <= trailing; ++k) {
+      op_labels[ndim - k] = op_labels[idim - k];
+    }
+    /* The axes opened up at the ellipsis position are broadcast axes: label 0. */
+    const int broadcast = ndim - idim;
+    for (int k = 0; k < broadcast; ++k) {
+      op_labels[ellipsis + k] = 0;
+    }
+  }
+
+  /*
+   * Second pass: replace every repeated label by the (negative) distance back to
+   * the first axis carrying that label. Entries that are 0 or already negative are
+   * not labels and are skipped.
+   */
+  char *const labels_end = op_labels + ndim;
+  for (int first = 0; first < ndim - 1; ++first) {
+    const int label = op_labels[first];
+    if (label <= 0) {
+      continue;
+    }
+    char *const anchor = op_labels + first;
+    for (char *dup = static_cast<char*>(memchr(anchor + 1, label, labels_end - (anchor + 1)));
+         dup != nullptr;
+         dup = static_cast<char*>(memchr(dup + 1, label, labels_end - (dup + 1)))) {
+      *dup = static_cast<char>(anchor - dup);
+    }
+  }
+  return 0;
+}
+
+/*!
+ * \brief Parse the output subscripts of an einsum expression into per-axis labels.
+ *
+ * Letters become output labels, "..." expands to ndim_broadcast entries of 0 and
+ * spaces are skipped; any other character fails a CHECK. A label may appear only once
+ * in the output and must have a non-zero entry in label_counts.
+ *
+ * \param subscripts The output subscripts (read for length characters).
+ * \param length The number of characters of subscripts to read.
+ * \param ndim_broadcast The number of broadcast axes inserted for an ellipsis.
+ * \param label_counts Counters indexed by label character, as filled for the inputs.
+ * \param out_labels Output buffer; at most NPY_MAXDIMS entries are written.
+ * \return The number of entries written to out_labels.
+ */
+inline int parse_output_subscripts(const char *subscripts, int length,
+                                   int ndim_broadcast,
+                                   const char *label_counts, char *out_labels) {
+  /* Number of entries written to out_labels so far. */
+  int ndim = 0;
+  /* Becomes 1 once an ellipsis has been consumed. */
+  int ellipsis = 0;
+
+  for (int i = 0; i < length; ++i) {
+    const int label = subscripts[i];
+
+    if (label > 0 && isalpha(label)) {
+      /* The same label must not show up again further along the output. */
+      CHECK(memchr(subscripts + i + 1, label, length - i - 1) == nullptr)
+        << "einstein sum subscripts string includes "
+        << "output subscript '" << static_cast<char>(label)
+        << "' multiple times";
+
+      /* An output label has to come from one of the inputs. */
+      CHECK(label_counts[label] != 0)
+        << "einstein sum subscripts string included "
+        << "output subscript '" << static_cast<char>(label)
+        << "' which never appeared "
+        << "in an input";
+
+      /* Make sure out_labels can take one more entry. */
+      CHECK(ndim < NPY_MAXDIMS)
+        << "einstein sum subscripts string contains "
+        << "too many subscripts in the output";
+
+      out_labels[ndim++] = label;
+      continue;
+    }
+
+    if (label == '.') {
+      /*
+       * Only one ellipsis is allowed and it must be exactly "...".
+       * The condition itself steps i over the two following dots.
+       */
+      CHECK(!(ellipsis || i + 2 >= length
+              || subscripts[++i] != '.' || subscripts[++i] != '.'))
+        << "einstein sum subscripts string "
+        << "contains a '.' that is not part of "
+        << "an ellipsis ('...') in the output";
+
+      /* Make sure out_labels can take all the broadcast entries. */
+      CHECK(ndim + ndim_broadcast <= NPY_MAXDIMS)
+        << "einstein sum subscripts string contains "
+        << "too many subscripts in the output";
+
+      ellipsis = 1;
+      for (int remaining = ndim_broadcast; remaining > 0; --remaining) {
+        out_labels[ndim++] = 0;
+      }
+      continue;
+    }
+
+    /* Anything that is neither a letter nor part of an ellipsis must be a space. */
+    CHECK(label == ' ')
+      << "invalid subscript '" << static_cast<char>(label)
+      << "' in einstein sum "
+      << "subscripts string, subscripts must "
+      << "be letters";
+  }
+
+  /* Broadcast axes can only be placed in the output through an ellipsis. */
+  CHECK(!(!ellipsis && ndim_broadcast > 0))
+    << "output has more dimensions than subscripts "
+    << "given in einstein sum, but no '...' ellipsis "
+    << "provided to broadcast the extra dimensions.";
+
+  return ndim;
+}
+
+inline void get_combined_dims_view(const Tensor& op, int iop,

Review comment:
       Fixed, thx




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org