You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/09/27 01:04:47 UTC

[GitHub] [tvm] vinx13 opened a new pull request, #12913: [TOPI] Implement Einsum with reduction axes

vinx13 opened a new pull request, #12913:
URL: https://github.com/apache/tvm/pull/12913

   The current Einsum doesn't utilize `ReduceOp` for reduction. Instead, it unrolls the whole reduction when defining the computation. This makes the generated result unsuitable for analysis, as it doesn't contain reduction information. It's also broken for large reductions. This PR provides a new implementation that generates the reduction axes correctly.
   Graph-level optimizations might be needed for an efficient Einsum. This PR only provides the compute definition for further analysis and optimizations.
   
   cc @spectrometerHBH @MasterJH5574 @junrushao @masahi 


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] vinx13 commented on a diff in pull request #12913: [TOPI] Implement Einsum with reduction axes

Posted by GitBox <gi...@apache.org>.
vinx13 commented on code in PR #12913:
URL: https://github.com/apache/tvm/pull/12913#discussion_r981688984


##########
include/tvm/topi/einsum.h:
##########
@@ -49,623 +49,15 @@ namespace topi {
 using namespace tvm::te;
 using namespace topi::detail;
 
-/*!
- * \brief Compute the stride of the given shape.
- *
- * \param shape for the operation.
- *
- * \return the stride of the shape.
- */
-inline Array<PrimExpr> GetStride(const Array<PrimExpr> shape) {
-  size_t ndim = shape.size();
-  int prod = 1;
-  Array<PrimExpr> stride = Array<PrimExpr>(ndim, -1);
-  for (int i = ndim - 1; i >= 0; i--) {
-    stride.Set(i, if_then_else(shape[i] > 1, prod, 0));
-    prod = prod * GetConstInt(shape[i]);
-  }
-  return stride;
-}
-
-/*!
- * \brief Pad the shape with 1.
- *
- * \param shape the input shape to be padded
- * \param odim the padding size of the objective shape.
- *
- * \return the padded shape.
- */
-inline Array<PrimExpr> Pad(const Array<PrimExpr> shape, int odim) {
-  int ndim = shape.size();
-  CHECK_GE(odim, ndim);
-  Array<PrimExpr> ret(static_cast<size_t>(odim), 1);
-  for (int idim = 0; idim < ndim; ++idim) {
-    ret.Set(idim, shape[idim]);
-  }
-  return ret;
-}
-
-/*!
- * \brief Parse the subscripts for one operand into an output of 'ndim' labels.
- *
- * \param subscripts the subscripts for to be parsed.
- * \param length subscripts[0: length] represents the current operand.
- * \param ndim the ndim of current operand.
- * \param iop the index of the operand.
- * \param op_labels the parsing result.
- *        For Example:
- *           subscripts="abbcbc",  ndim=6 -> op_labels=[97, 98, -1, 99, -3, -2].
- *           subscripts="ab...bc", ndim=6 -> op_labels=[97, 98, 0, 0, -3, 99].
- * \param label_counts Count the number the label appears.
- * \param min_label Save the minimal label according to ASCII.
- * \param max_label Save the maximal label according to ASCII.
- *
- * \return 0.
- */
-inline int ParseOperandSubscripts(const char* subscripts, int length, int ndim, int iop,
-                                  char* op_labels, char* label_counts, int* min_label,
-                                  int* max_label) {
-  int i;
-  int idim = 0;
-  int ellipsis = -1;
-
-  /* Process all labels for this operand */
-  for (i = 0; i < length; ++i) {
-    int label = subscripts[i];
-
-    /* A proper label for an axis. */
-    if (label > 0 && isalpha(label)) {
-      /* Check we don't exceed the operator dimensions. */
-      CHECK(idim < ndim) << "einstein sum subscripts string contains "
-                         << "too many subscripts for operand " << iop;
-
-      op_labels[idim++] = label;
-      if (label < *min_label) {
-        *min_label = label;
-      }
-      if (label > *max_label) {
-        *max_label = label;
-      }
-      label_counts[label]++;
-    } else if (label == '.') {
-      /* The beginning of the ellipsis. */
-      /* Check it's a proper ellipsis. */
-      CHECK(
-          !(ellipsis != -1 || i + 2 >= length || subscripts[++i] != '.' || subscripts[++i] != '.'))
-          << "einstein sum subscripts string contains a "
-          << "'.' that is not part of an ellipsis ('...') "
-          << "in operand " << iop;
-
-      ellipsis = idim;
-    } else {
-      CHECK(label == ' ') << "invalid subscript '" << static_cast<char>(label)
-                          << "' in einstein sum "
-                          << "subscripts string, subscripts must "
-                          << "be letters";
-    }
-  }
-
-  /* No ellipsis found, labels must match dimensions exactly. */
-  if (ellipsis == -1) {
-    CHECK(idim == ndim) << "operand has more dimensions than subscripts "
-                        << "given in einstein sum, but no '...' ellipsis "
-                        << "provided to broadcast the extra dimensions.";
-  } else if (idim < ndim) {
-    /* Ellipsis found, may have to add broadcast dimensions. */
-    /* Move labels after ellipsis to the end. */
-    for (i = 0; i < idim - ellipsis; ++i) {
-      op_labels[ndim - i - 1] = op_labels[idim - i - 1];
-    }
-    /* Set all broadcast dimensions to zero. */
-    for (i = 0; i < ndim - idim; ++i) {
-      op_labels[ellipsis + i] = 0;
-    }
-  }
-
-  /*
-   * Find any labels duplicated for this operand, and turn them
-   * into negative offsets to the axis to merge with.
-   *
-   * In C, the char type may be signed or unsigned, but with
-   * twos complement arithmetic the char is ok either way here, and
-   * later where it matters the char is cast to a signed char.
-   */
-  for (idim = 0; idim < ndim - 1; ++idim) {
-    int label = op_labels[idim];
-    /* If it is a proper label, find any duplicates of it. */
-    if (label > 0) {
-      /* Search for the next matching label. */
-      char* next = reinterpret_cast<char*>(memchr(op_labels + idim + 1, label, ndim - idim - 1));
-
-      while (next != nullptr) {
-        /* The offset from next to op_labels[idim] (negative). */
-        *next = static_cast<char>((op_labels + idim) - next);
-        /* Search for the next matching label. */
-        next = reinterpret_cast<char*>(memchr(next + 1, label, op_labels + ndim - 1 - next));
-      }
-    }
-  }
-  return 0;
-}
-
-/*!
- * \brief Parse the subscripts for the output into an output that includes 'ndim_broadcast'
- *        unlabeled dimensions.
- *
- * \param subscripts the subscripts for to be parsed.
- * \param length subscripts[0: length] represents the output operand.
- * \param ndim_broadcast the broadcast dimension number.
- * \param label_counts Count the number the label appears.
- * \param out_labels similar to the op_labels in ParseOperandSubscripts, for each
- *        dimension, the ASCII code of the corresponding label. zero for the broadcasting dim.
- *
- * \return the total number of output dimensions or -1 if there is an error.
- */
-inline int ParseOutputSubscripts(const char* subscripts, int length, int ndim_broadcast,
-                                 const char* label_counts, char* out_labels) {
-  int i, bdim;
-  int ndim = 0;
-  int ellipsis = 0;
-
-  /* Process all the output labels. */
-  for (i = 0; i < length; ++i) {
-    int label = subscripts[i];
-
-    /* A proper label for an axis. */
-    if (label > 0 && isalpha(label)) {
-      /* Check that it doesn't occur again. */
-      CHECK(memchr(subscripts + i + 1, label, length - i - 1) == nullptr)
-          << "einstein sum subscripts string includes "
-          << "output subscript '" << static_cast<char>(label) << "' multiple times";
-
-      /* Check that it was used in the inputs. */
-      CHECK(label_counts[label] != 0)
-          << "einstein sum subscripts string included "
-          << "output subscript '" << static_cast<char>(label) << "' which never appeared "
-          << "in an input";
-
-      /* Check that there is room in out_labels for this label. */
-      CHECK(ndim < NPY_MAXDIMS) << "einstein sum subscripts string contains "
-                                << "too many subscripts in the output";
-
-      out_labels[ndim++] = label;
-    } else if (label == '.') {
-      /* The beginning of the ellipsis. */
-      /* Check it is a proper ellipsis. */
-      CHECK(!(ellipsis || i + 2 >= length || subscripts[++i] != '.' || subscripts[++i] != '.'))
-          << "einstein sum subscripts string "
-          << "contains a '.' that is not part of "
-          << "an ellipsis ('...') in the output";
-
-      /* Check there is room in out_labels for broadcast dims. */
-      CHECK(ndim + ndim_broadcast <= NPY_MAXDIMS) << "einstein sum subscripts string contains "
-                                                  << "too many subscripts in the output";
-
-      ellipsis = 1;
-      for (bdim = 0; bdim < ndim_broadcast; ++bdim) {
-        out_labels[ndim++] = 0;
-      }
-    } else {
-      CHECK(label == ' ') << "invalid subscript '" << static_cast<char>(label)
-                          << "' in einstein sum "
-                          << "subscripts string, subscripts must "
-                          << "be letters";
-    }
-  }
-
-  /* If no ellipsis was found there should be no broadcast dimensions. */
-  CHECK(!(!ellipsis && ndim_broadcast > 0)) << "output has more dimensions than subscripts "
-                                            << "given in einstein sum, but no '...' ellipsis "
-                                            << "provided to broadcast the extra dimensions.";
-
-  return ndim;
-}
-
-/*!
- * \brief If any dimensions are combined, create a view that combines them.
- *        Shows in newshape and newstride.
- *
- * \param op the operand tensor.
- * \param iop the index of the operand.
- * \param labels the op_labels fot the operand. Like [97, 98, -2] for "aba".
- * \param newshape The combined shape.
- * \param newstride The combined stride.
- *
- * For example:
- *  "aba -> ab",              shape = [2,3,2] stride = [6,2,1]
- *  op_labels = [97, 98, -2], newshape = [2,3], newstride = [7,2]
- */
-inline void GetCombinedDimsView(const Tensor& op, int iop, char* labels, Array<PrimExpr>* newshape,
-                                Array<PrimExpr>* newstride) {
-  int idim, ndim, icombine, combineoffset;
-  int icombinemap[NPY_MAXDIMS];
-  int newdim;
-
-  Array<PrimExpr> shape = op->shape;
-  Array<PrimExpr> stride = GetStride(shape);
-  ndim = op.ndim();
-  newdim = newshape->size();
-
-  /* Initialize the dimensions and strides to zero */
-  for (idim = 0; idim < newdim; ++idim) {
-    newshape->Set(idim, 0);
-    newstride->Set(idim, 0);
-  }
-
-  /* Copy the dimensions and strides, except when collapsing */
-  icombine = 0;
-  for (idim = 0; idim < ndim; ++idim) {
-    /*
-     * The char type may be either signed or unsigned, we
-     * need it to be signed here.
-     */
-    int label = (signed char)labels[idim];
-    /* If this label says to merge axes, get the actual label */
-    if (label < 0) {
-      combineoffset = label;
-      label = labels[idim + label];
-    } else {
-      combineoffset = 0;
-      if (icombine != idim) {
-        labels[icombine] = labels[idim];
-      }
-      icombinemap[idim] = icombine;
-    }
-    /* If the label is 0, it's an unlabeled broadcast dimension */
-    if (label == 0) {
-      newshape->Set(icombine, shape[idim]);
-      newstride->Set(icombine, stride[idim]);
-    } else {
-      /* Update the combined axis dimensions and strides */
-      int i = icombinemap[idim + combineoffset];
-      CHECK(!((combineoffset < 0) &&
-              GetConstInt((*newshape)[i] != 0 && (*newshape)[i] != shape[idim])))
-          << "dimensions in operand " << iop << " for collapsing index '" << label
-          << "' don't match (" << GetConstInt((*newshape)[i]) << " != " << shape[idim] << ")";
-      newshape->Set(i, shape[idim]);
-      newstride->Set(i, (*newstride)[i] + stride[idim]);
-    }
-
-    /* If the label didn't say to combine axes, increment dest i */
-    if (combineoffset == 0) {
-      icombine++;
-    }
-  }
-}
-
-/*!
- * \brief Prepare the operand axes to match each stride or shape pair.
- *
- * \param ndim the ndim of the operand tensor.
- * \param iop the index of the operand.
- * \param labels the op_labels fot the operand. [97, 98, -1, 99, -3, -2] for "abbcbc".
- * \param axes The matched axes to be calculated.
- * \param ndim_iter the dimension of iterating. Subscripts "ab, bc -> ac" ndim_iter = 3.
- * \param iter_labels output_labels with the iterating label. ['a', 'c', 'b'] for the case above.
- */
-inline static int PrepareOpAxes(int ndim, int iop, char* labels, int* axes, int ndim_iter,
-                                char* iter_labels) {
-  int i, label, ibroadcast;
-
-  ibroadcast = ndim - 1;
-  for (i = ndim_iter - 1; i >= 0; --i) {
-    label = iter_labels[i];
-    /*
-     * If it's an unlabeled broadcast dimension, choose
-     * the next broadcast dimension from the operand.
-     */
-    if (label == 0) {
-      while (ibroadcast >= 0 && labels[ibroadcast] != 0) {
-        --ibroadcast;
-      }
-      /*
-       * If we used up all the operand broadcast dimensions,
-       * extend it with a "newaxis"
-       */
-      if (ibroadcast < 0) {
-        axes[i] = -1;
-      } else {
-        /* Otherwise map to the broadcast axis */
-        axes[i] = ibroadcast;
-        --ibroadcast;
-      }
-    } else {
-      /* It's a labeled dimension, find the matching one */
-      char* match = reinterpret_cast<char*>(memchr(labels, label, ndim));
-      /* If the op doesn't have the label, broadcast it */
-      if (match == nullptr) {
-        axes[i] = -1;
-      } else {
-        /* Otherwise use it */
-        axes[i] = match - labels;
-      }
-    }
-  }
-  return 0;
-}
-
-/*!
- * \brief Count SubString.
- * \param str the object string
- * \param sub the pattern string
- *
- * \return number of substring
- */
-inline int CountSubstring(const std::string& str, const std::string& sub) {
-  int count = 0;
-  std::string::size_type pos = 0;
-  while ((pos = str.find(sub, pos)) != std::string::npos) {
-    ++count;
-    pos += sub.length();
-  }
-  return count;
-}
-
-/*!
- * \brief Transfer string to.
- * \param str input string.
- *
- * \return bitset.
- */
-inline std::bitset<LABELRANGE> Str2Set(const std::string& str) {
-  std::bitset<LABELRANGE> ret;
-  for (const char& c : str) {
-    ret.set(static_cast<int>(c));
-  }
-  return ret;
-}
-
-/*!
- * \brief Split str according to substring.
- * \param str input string.
- * \param sub the split pattern string.
- *
- * \return vector contains the splited substring.
- */
-inline std::vector<std::string> Split(const std::string& str, const std::string& sub) {
-  std::string::size_type pos = 0;
-  std::string::size_type start = 0;
-  std::vector<std::string> ret;
-  while ((pos = str.find(sub, start)) != std::string::npos) {
-    ret.push_back(str.substr(start, pos - start));
-    start = pos + sub.length();
-  }
-  ret.push_back(str.substr(start));
-  return ret;
-}
-
-/*!
- * \brief Parse the input subscripts into a vector of strings.
- * \param subscripts input subscripts.
- * \param operands operand tensors.
- *
- * \return vector of strings, vector[0] represents the input part, vector[1] represents the output.
- * if no output, the vector[1] is NULL.
- * "ab, bc -> ac" => ["ab,bc", "ac"]
- */
-inline std::tuple<std::string, std::string> ParseEinsumInput(
-    std::string subscripts, const std::vector<Array<PrimExpr>>& operands) {
-  const std::string einsum_symbols = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
-  std::bitset<LABELRANGE> einsum_symbols_set;
-  for (const char& c : einsum_symbols) {
-    einsum_symbols_set.set(c);
-  }
-
-  CHECK_NE(operands.size(), 0U) << "No input operands";
-
-  auto end_pos = std::remove(subscripts.begin(), subscripts.end(), ' ');
-  subscripts.erase(end_pos, subscripts.end());
-
-  // Ensure all characters are valid
-  for (const char& c : subscripts) {
-    if (c == '.' || c == ',' || c == '-' || c == '>') {
-      continue;
-    }
-    CHECK(einsum_symbols_set.test(c)) << "Character " << c << " is not a valid symbol.";
-  }
-
-  // Check for proper "->"
-  if (subscripts.find('-') != std::string::npos || subscripts.find('>') != std::string::npos) {
-    bool invalid = (std::count(subscripts.begin(), subscripts.end(), '-') > 1 ||
-                    std::count(subscripts.begin(), subscripts.end(), '>') > 1);
-    CHECK(!invalid && CountSubstring(subscripts, "->") == 1)
-        << "Subscripts can only contain one '->'.";
-  }
-
-  // Parse ellipses
-  if (subscripts.find('.') != std::string::npos) {
-    std::string used = subscripts;
-    used.erase(
-        std::remove_if(used.begin(), used.end(),
-                       [](const char& c) { return c == '.' || c == ',' || c == '-' || c == '>'; }),
-        used.end());
-
-    std::bitset<LABELRANGE> used_set = Str2Set(used);
-    std::string ellipse_inds = "";
-    for (const char& c : einsum_symbols) {
-      if (!used_set.test(static_cast<int>(c))) {
-        ellipse_inds.append(1, c);
-      }
-    }
-    int longest = 0;
-    std::string input_tmp, output_sub;
-    std::vector<std::string> split_subscripts;
-    bool out_sub;
-
-    if (subscripts.find("->") != std::string::npos) {
-      std::vector<std::string> tmp = Split(subscripts, "->");
-      input_tmp = tmp[0];
-      output_sub = tmp[1];
-      split_subscripts = Split(input_tmp, ",");
-      out_sub = true;
-    } else {
-      split_subscripts = Split(subscripts, ",");
-      out_sub = false;
-    }
-
-    size_t size_split_subscripts = split_subscripts.size();
-    subscripts = "";
-    for (size_t i = 0; i < size_split_subscripts; ++i) {
-      const std::string& sub = split_subscripts[i];
-      if (sub.find('.') != std::string::npos) {
-        CHECK_EQ(std::count(sub.begin(), sub.end(), '.'), 3) << "Invalid Ellipses";
-        CHECK_EQ(CountSubstring(sub, "..."), 1) << "Invalid Ellipses";
-
-        // Take into account numerical values
-        int ellipse_count = 0;
-        if (operands[i].size() == 0) {
-          ellipse_count = 0;
-        } else {
-          ellipse_count = std::max(operands[i].size(), static_cast<size_t>(1));
-          ellipse_count -= sub.length() - 3;
-        }
-
-        if (ellipse_count > longest) {
-          longest = ellipse_count;
-        }
-
-        CHECK_GE(ellipse_count, 0) << "Ellipses lengths do not match.";
-        if (ellipse_count == 0) {
-          split_subscripts[i].erase(sub.find("..."), 3);
-        } else {
-          std::string rep_inds = ellipse_inds.substr(ellipse_inds.length() - ellipse_count);
-          split_subscripts[i].replace(sub.find("..."), 3, rep_inds);
-        }
-      }
-      subscripts += split_subscripts[i];
-      if (i + 1 < size_split_subscripts) {
-        subscripts += ",";
-      }
-    }
-    std::string out_ellipse;
-    if (longest == 0) {
-      out_ellipse = "";
-    } else {
-      out_ellipse = ellipse_inds.substr(ellipse_inds.length() - longest);
-    }
-
-    if (out_sub) {
-      output_sub.replace(output_sub.find("..."), 3, out_ellipse);
-      subscripts += "->" + output_sub;
-    } else {
-      // Special care for outputless ellipses
-      std::bitset<LABELRANGE> out_ellipse_set = Str2Set(out_ellipse);
-      std::string tmp_subscripts = subscripts, output_subscript = "";
-      size_t len_tmp_subscripts = tmp_subscripts.length();
-      std::sort(tmp_subscripts.begin(), tmp_subscripts.end());
-      for (size_t i = 0; i < len_tmp_subscripts; ++i) {
-        const char& c = tmp_subscripts[i];
-        if (c == ',') {
-          continue;
-        }
-        CHECK(einsum_symbols_set.test(c)) << "Character " << c << " is not a valid symbol.";
-        if ((i == 0 || tmp_subscripts[i - 1] != c) &&
-            (i == len_tmp_subscripts - 1 || tmp_subscripts[i + 1] != c) &&
-            !out_ellipse_set.test(c)) {
-          output_subscript.append(1, c);
-        }
-      }
-      subscripts += "->" + out_ellipse + output_subscript;
-    }
-  }
-
-  // Build output string if does not exist
-  std::tuple<std::string, std::string> ret;
-  if (subscripts.find("->") != std::string::npos) {
-    std::vector<std::string> tmp(2);
-    tmp = Split(subscripts, "->");
-    ret = std::make_tuple(tmp[0], tmp[1]);
-  } else {
-    std::string first = subscripts;
-    std::string second = "";
-    // Build output subscripts
-    std::string tmp_subscripts = subscripts;
-    size_t len_tmp_subscripts = tmp_subscripts.length();
-    std::sort(tmp_subscripts.begin(), tmp_subscripts.end());
-    for (size_t i = 0; i < len_tmp_subscripts; ++i) {
-      const char& c = tmp_subscripts[i];
-      if (c == ',') {
-        continue;
-      }
-      CHECK(einsum_symbols_set.test(c)) << "Character " << c << " is not a valid symbol.";
-      if ((i == 0 || tmp_subscripts[i - 1] != c) &&
-          (i == len_tmp_subscripts - 1 || tmp_subscripts[i + 1] != c)) {
-        second.append(1, c);
-      }
-    }
-    ret = std::make_tuple(first, second);
-  }
-
-  // Make sure output subscripts are in the input
-  std::bitset<LABELRANGE> input_subscripts_set = Str2Set(std::get<0>(ret));
-  for (const char& c : std::get<1>(ret)) {
-    CHECK(input_subscripts_set.test(c))
-        << "Output character " << c << " did not appear in the input";
-  }
-
-  // Make sure number operands is equivalent to the number of terms
-  CHECK_EQ(std::count(std::get<0>(ret).begin(), std::get<0>(ret).end(), ',') + 1, operands.size())
-      << "Number of einsum subscripts must be equal to the "
-      << "number of operands.";
-
-  return ret;
-}
-
 /*!
  * \brief Compute the shape of the output.
  * \param subscripts input subscripts.
  * \param operands operand tensors.
  *
  * \return the shape of the output.
  */
-inline Array<PrimExpr> NumpyEinsumShape(const std::string subscripts,
-                                        const std::vector<Array<PrimExpr>>& operands) {
-  // Parsing
-  std::tuple<std::string, std::string> parsed_subscripts = ParseEinsumInput(subscripts, operands);
-
-  // Build a few useful list and sets
-  std::vector<std::string> input_list = Split(std::get<0>(parsed_subscripts), ",");
-  size_t isize = input_list.size();
-
-  // Get length of each unique dimension and ensure all dimensions are correct
-  int dimension_dict[LABELRANGE];
-  memset(dimension_dict, -1, sizeof(dimension_dict));
-  for (size_t i = 0; i < isize; ++i) {
-    const std::string& term = input_list[i];
-    const Array<PrimExpr>& sh = operands[i];
-    CHECK_EQ(sh.size(), term.length())
-        << "Einstein sum subscript " << input_list[i] << " does not contain the "
-        << "correct number of indices for operand " << i << ".";
-    size_t len_term = term.length();
-    for (size_t j = 0; j < len_term; ++j) {
-      int64_t dim = GetConstInt(sh[j]);
-      const char& c = term[j];
-
-      if (dimension_dict[static_cast<int>(c)] != -1) {
-        // For broadcasting cases we always want the largest dim size
-        if (dimension_dict[static_cast<int>(c)] == 1) {
-          dimension_dict[static_cast<int>(c)] = dim;
-        }
-        CHECK(dim == 1 || dim == dimension_dict[static_cast<int>(c)])
-            << "Size of label '" << c << "' for operand  " << i << " ("
-            << dimension_dict[static_cast<int>(c)] << ") does not match previous terms (" << dim
-            << ").";
-      } else {
-        dimension_dict[static_cast<int>(c)] = dim;
-      }
-    }
-  }
-
-  // Get oshape
-  const std::string& output_str = std::get<1>(parsed_subscripts);
-  size_t odim = output_str.size();
-  Array<PrimExpr> oshape(odim, -1);
-  for (size_t i = 0; i < odim; ++i) {
-    oshape.Set(i, dimension_dict[static_cast<int>(output_str[i])]);
-  }
-  // Neglecting oshape assign check temporally
-  return oshape;
-}
+Array<PrimExpr> NumpyEinsumShape(const std::string subscripts,

Review Comment:
   This is the name used previously. Indeed, there's no reason to use the numpy prefix; we can probably find a better name.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] MasterJH5574 commented on a diff in pull request #12913: [TOPI] Implement Einsum with reduction axes

Posted by GitBox <gi...@apache.org>.
MasterJH5574 commented on code in PR #12913:
URL: https://github.com/apache/tvm/pull/12913#discussion_r981687088


##########
src/topi/transform.cc:
##########
@@ -173,10 +173,6 @@ TVM_REGISTER_GLOBAL("topi.tensordot").set_body([](TVMArgs args, TVMRetValue* rv)
   }
 });
 
-TVM_REGISTER_GLOBAL("topi.einsum").set_body([](TVMArgs args, TVMRetValue* rv) {
-  *rv = einsum(args[0], args[1]);
-});
-

Review Comment:
   Is this no longer used?



##########
include/tvm/topi/einsum.h:
##########
@@ -49,623 +49,15 @@ namespace topi {
 using namespace tvm::te;
 using namespace topi::detail;
 
-/*!
- * \brief Compute the stride of the given shape.
- *
- * \param shape for the operation.
- *
- * \return the stride of the shape.
- */
-inline Array<PrimExpr> GetStride(const Array<PrimExpr> shape) {
-  size_t ndim = shape.size();
-  int prod = 1;
-  Array<PrimExpr> stride = Array<PrimExpr>(ndim, -1);
-  for (int i = ndim - 1; i >= 0; i--) {
-    stride.Set(i, if_then_else(shape[i] > 1, prod, 0));
-    prod = prod * GetConstInt(shape[i]);
-  }
-  return stride;
-}
-
-/*!
- * \brief Pad the shape with 1.
- *
- * \param shape the input shape to be padded
- * \param odim the padding size of the objective shape.
- *
- * \return the padded shape.
- */
-inline Array<PrimExpr> Pad(const Array<PrimExpr> shape, int odim) {
-  int ndim = shape.size();
-  CHECK_GE(odim, ndim);
-  Array<PrimExpr> ret(static_cast<size_t>(odim), 1);
-  for (int idim = 0; idim < ndim; ++idim) {
-    ret.Set(idim, shape[idim]);
-  }
-  return ret;
-}
-
-/*!
- * \brief Parse the subscripts for one operand into an output of 'ndim' labels.
- *
- * \param subscripts the subscripts for to be parsed.
- * \param length subscripts[0: length] represents the current operand.
- * \param ndim the ndim of current operand.
- * \param iop the index of the operand.
- * \param op_labels the parsing result.
- *        For Example:
- *           subscripts="abbcbc",  ndim=6 -> op_labels=[97, 98, -1, 99, -3, -2].
- *           subscripts="ab...bc", ndim=6 -> op_labels=[97, 98, 0, 0, -3, 99].
- * \param label_counts Count the number the label appears.
- * \param min_label Save the minimal label according to ASCII.
- * \param max_label Save the maximal label according to ASCII.
- *
- * \return 0.
- */
-inline int ParseOperandSubscripts(const char* subscripts, int length, int ndim, int iop,
-                                  char* op_labels, char* label_counts, int* min_label,
-                                  int* max_label) {
-  int i;
-  int idim = 0;
-  int ellipsis = -1;
-
-  /* Process all labels for this operand */
-  for (i = 0; i < length; ++i) {
-    int label = subscripts[i];
-
-    /* A proper label for an axis. */
-    if (label > 0 && isalpha(label)) {
-      /* Check we don't exceed the operator dimensions. */
-      CHECK(idim < ndim) << "einstein sum subscripts string contains "
-                         << "too many subscripts for operand " << iop;
-
-      op_labels[idim++] = label;
-      if (label < *min_label) {
-        *min_label = label;
-      }
-      if (label > *max_label) {
-        *max_label = label;
-      }
-      label_counts[label]++;
-    } else if (label == '.') {
-      /* The beginning of the ellipsis. */
-      /* Check it's a proper ellipsis. */
-      CHECK(
-          !(ellipsis != -1 || i + 2 >= length || subscripts[++i] != '.' || subscripts[++i] != '.'))
-          << "einstein sum subscripts string contains a "
-          << "'.' that is not part of an ellipsis ('...') "
-          << "in operand " << iop;
-
-      ellipsis = idim;
-    } else {
-      CHECK(label == ' ') << "invalid subscript '" << static_cast<char>(label)
-                          << "' in einstein sum "
-                          << "subscripts string, subscripts must "
-                          << "be letters";
-    }
-  }
-
-  /* No ellipsis found, labels must match dimensions exactly. */
-  if (ellipsis == -1) {
-    CHECK(idim == ndim) << "operand has more dimensions than subscripts "
-                        << "given in einstein sum, but no '...' ellipsis "
-                        << "provided to broadcast the extra dimensions.";
-  } else if (idim < ndim) {
-    /* Ellipsis found, may have to add broadcast dimensions. */
-    /* Move labels after ellipsis to the end. */
-    for (i = 0; i < idim - ellipsis; ++i) {
-      op_labels[ndim - i - 1] = op_labels[idim - i - 1];
-    }
-    /* Set all broadcast dimensions to zero. */
-    for (i = 0; i < ndim - idim; ++i) {
-      op_labels[ellipsis + i] = 0;
-    }
-  }
-
-  /*
-   * Find any labels duplicated for this operand, and turn them
-   * into negative offsets to the axis to merge with.
-   *
-   * In C, the char type may be signed or unsigned, but with
-   * twos complement arithmetic the char is ok either way here, and
-   * later where it matters the char is cast to a signed char.
-   */
-  for (idim = 0; idim < ndim - 1; ++idim) {
-    int label = op_labels[idim];
-    /* If it is a proper label, find any duplicates of it. */
-    if (label > 0) {
-      /* Search for the next matching label. */
-      char* next = reinterpret_cast<char*>(memchr(op_labels + idim + 1, label, ndim - idim - 1));
-
-      while (next != nullptr) {
-        /* The offset from next to op_labels[idim] (negative). */
-        *next = static_cast<char>((op_labels + idim) - next);
-        /* Search for the next matching label. */
-        next = reinterpret_cast<char*>(memchr(next + 1, label, op_labels + ndim - 1 - next));
-      }
-    }
-  }
-  return 0;
-}
-
-/*!
- * \brief Parse the subscripts for the output into an output that includes 'ndim_broadcast'
- *        unlabeled dimensions.
- *
- * \param subscripts the subscripts for to be parsed.
- * \param length subscripts[0: length] represents the output operand.
- * \param ndim_broadcast the broadcast dimension number.
- * \param label_counts Count the number the label appears.
- * \param out_labels similar to the op_labels in ParseOperandSubscripts, for each
- *        dimension, the ASCII code of the corresponding label. zero for the broadcasting dim.
- *
- * \return the total number of output dimensions or -1 if there is an error.
- */
-inline int ParseOutputSubscripts(const char* subscripts, int length, int ndim_broadcast,
-                                 const char* label_counts, char* out_labels) {
-  int i, bdim;
-  int ndim = 0;
-  int ellipsis = 0;
-
-  /* Process all the output labels. */
-  for (i = 0; i < length; ++i) {
-    int label = subscripts[i];
-
-    /* A proper label for an axis. */
-    if (label > 0 && isalpha(label)) {
-      /* Check that it doesn't occur again. */
-      CHECK(memchr(subscripts + i + 1, label, length - i - 1) == nullptr)
-          << "einstein sum subscripts string includes "
-          << "output subscript '" << static_cast<char>(label) << "' multiple times";
-
-      /* Check that it was used in the inputs. */
-      CHECK(label_counts[label] != 0)
-          << "einstein sum subscripts string included "
-          << "output subscript '" << static_cast<char>(label) << "' which never appeared "
-          << "in an input";
-
-      /* Check that there is room in out_labels for this label. */
-      CHECK(ndim < NPY_MAXDIMS) << "einstein sum subscripts string contains "
-                                << "too many subscripts in the output";
-
-      out_labels[ndim++] = label;
-    } else if (label == '.') {
-      /* The beginning of the ellipsis. */
-      /* Check it is a proper ellipsis. */
-      CHECK(!(ellipsis || i + 2 >= length || subscripts[++i] != '.' || subscripts[++i] != '.'))
-          << "einstein sum subscripts string "
-          << "contains a '.' that is not part of "
-          << "an ellipsis ('...') in the output";
-
-      /* Check there is room in out_labels for broadcast dims. */
-      CHECK(ndim + ndim_broadcast <= NPY_MAXDIMS) << "einstein sum subscripts string contains "
-                                                  << "too many subscripts in the output";
-
-      ellipsis = 1;
-      for (bdim = 0; bdim < ndim_broadcast; ++bdim) {
-        out_labels[ndim++] = 0;
-      }
-    } else {
-      CHECK(label == ' ') << "invalid subscript '" << static_cast<char>(label)
-                          << "' in einstein sum "
-                          << "subscripts string, subscripts must "
-                          << "be letters";
-    }
-  }
-
-  /* If no ellipsis was found there should be no broadcast dimensions. */
-  CHECK(!(!ellipsis && ndim_broadcast > 0)) << "output has more dimensions than subscripts "
-                                            << "given in einstein sum, but no '...' ellipsis "
-                                            << "provided to broadcast the extra dimensions.";
-
-  return ndim;
-}
-
-/*!
- * \brief If any dimensions are combined, create a view that combines them.
- *        Shows in newshape and newstride.
- *
- * \param op the operand tensor.
- * \param iop the index of the operand.
- * \param labels the op_labels fot the operand. Like [97, 98, -2] for "aba".
- * \param newshape The combined shape.
- * \param newstride The combined stride.
- *
- * For example:
- *  "aba -> ab",              shape = [2,3,2] stride = [6,2,1]
- *  op_labels = [97, 98, -2], newshape = [2,3], newstride = [7,2]
- */
-inline void GetCombinedDimsView(const Tensor& op, int iop, char* labels, Array<PrimExpr>* newshape,
-                                Array<PrimExpr>* newstride) {
-  int idim, ndim, icombine, combineoffset;
-  int icombinemap[NPY_MAXDIMS];
-  int newdim;
-
-  Array<PrimExpr> shape = op->shape;
-  Array<PrimExpr> stride = GetStride(shape);
-  ndim = op.ndim();
-  newdim = newshape->size();
-
-  /* Initialize the dimensions and strides to zero */
-  for (idim = 0; idim < newdim; ++idim) {
-    newshape->Set(idim, 0);
-    newstride->Set(idim, 0);
-  }
-
-  /* Copy the dimensions and strides, except when collapsing */
-  icombine = 0;
-  for (idim = 0; idim < ndim; ++idim) {
-    /*
-     * The char type may be either signed or unsigned, we
-     * need it to be signed here.
-     */
-    int label = (signed char)labels[idim];
-    /* If this label says to merge axes, get the actual label */
-    if (label < 0) {
-      combineoffset = label;
-      label = labels[idim + label];
-    } else {
-      combineoffset = 0;
-      if (icombine != idim) {
-        labels[icombine] = labels[idim];
-      }
-      icombinemap[idim] = icombine;
-    }
-    /* If the label is 0, it's an unlabeled broadcast dimension */
-    if (label == 0) {
-      newshape->Set(icombine, shape[idim]);
-      newstride->Set(icombine, stride[idim]);
-    } else {
-      /* Update the combined axis dimensions and strides */
-      int i = icombinemap[idim + combineoffset];
-      CHECK(!((combineoffset < 0) &&
-              GetConstInt((*newshape)[i] != 0 && (*newshape)[i] != shape[idim])))
-          << "dimensions in operand " << iop << " for collapsing index '" << label
-          << "' don't match (" << GetConstInt((*newshape)[i]) << " != " << shape[idim] << ")";
-      newshape->Set(i, shape[idim]);
-      newstride->Set(i, (*newstride)[i] + stride[idim]);
-    }
-
-    /* If the label didn't say to combine axes, increment dest i */
-    if (combineoffset == 0) {
-      icombine++;
-    }
-  }
-}
-
-/*!
- * \brief Prepare the operand axes to match each stride or shape pair.
- *
- * \param ndim the ndim of the operand tensor.
- * \param iop the index of the operand.
- * \param labels the op_labels fot the operand. [97, 98, -1, 99, -3, -2] for "abbcbc".
- * \param axes The matched axes to be calculated.
- * \param ndim_iter the dimension of iterating. Subscripts "ab, bc -> ac" ndim_iter = 3.
- * \param iter_labels output_labels with the iterating label. ['a', 'c', 'b'] for the case above.
- */
-inline static int PrepareOpAxes(int ndim, int iop, char* labels, int* axes, int ndim_iter,
-                                char* iter_labels) {
-  int i, label, ibroadcast;
-
-  ibroadcast = ndim - 1;
-  for (i = ndim_iter - 1; i >= 0; --i) {
-    label = iter_labels[i];
-    /*
-     * If it's an unlabeled broadcast dimension, choose
-     * the next broadcast dimension from the operand.
-     */
-    if (label == 0) {
-      while (ibroadcast >= 0 && labels[ibroadcast] != 0) {
-        --ibroadcast;
-      }
-      /*
-       * If we used up all the operand broadcast dimensions,
-       * extend it with a "newaxis"
-       */
-      if (ibroadcast < 0) {
-        axes[i] = -1;
-      } else {
-        /* Otherwise map to the broadcast axis */
-        axes[i] = ibroadcast;
-        --ibroadcast;
-      }
-    } else {
-      /* It's a labeled dimension, find the matching one */
-      char* match = reinterpret_cast<char*>(memchr(labels, label, ndim));
-      /* If the op doesn't have the label, broadcast it */
-      if (match == nullptr) {
-        axes[i] = -1;
-      } else {
-        /* Otherwise use it */
-        axes[i] = match - labels;
-      }
-    }
-  }
-  return 0;
-}
-
-/*!
- * \brief Count SubString.
- * \param str the object string
- * \param sub the pattern string
- *
- * \return number of substring
- */
-inline int CountSubstring(const std::string& str, const std::string& sub) {
-  int count = 0;
-  std::string::size_type pos = 0;
-  while ((pos = str.find(sub, pos)) != std::string::npos) {
-    ++count;
-    pos += sub.length();
-  }
-  return count;
-}
-
-/*!
- * \brief Transfer string to.
- * \param str input string.
- *
- * \return bitset.
- */
-inline std::bitset<LABELRANGE> Str2Set(const std::string& str) {
-  std::bitset<LABELRANGE> ret;
-  for (const char& c : str) {
-    ret.set(static_cast<int>(c));
-  }
-  return ret;
-}
-
-/*!
- * \brief Split str according to substring.
- * \param str input string.
- * \param sub the split pattern string.
- *
- * \return vector contains the splited substring.
- */
-inline std::vector<std::string> Split(const std::string& str, const std::string& sub) {
-  std::string::size_type pos = 0;
-  std::string::size_type start = 0;
-  std::vector<std::string> ret;
-  while ((pos = str.find(sub, start)) != std::string::npos) {
-    ret.push_back(str.substr(start, pos - start));
-    start = pos + sub.length();
-  }
-  ret.push_back(str.substr(start));
-  return ret;
-}
-
-/*!
- * \brief Parse the input subscripts into a vector of strings.
- * \param subscripts input subscripts.
- * \param operands operand tensors.
- *
- * \return vector of strings, vector[0] represents the input part, vector[1] represents the output.
- * if no output, the vector[1] is NULL.
- * "ab, bc -> ac" => ["ab,bc", "ac"]
- */
-inline std::tuple<std::string, std::string> ParseEinsumInput(
-    std::string subscripts, const std::vector<Array<PrimExpr>>& operands) {
-  const std::string einsum_symbols = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
-  std::bitset<LABELRANGE> einsum_symbols_set;
-  for (const char& c : einsum_symbols) {
-    einsum_symbols_set.set(c);
-  }
-
-  CHECK_NE(operands.size(), 0U) << "No input operands";
-
-  auto end_pos = std::remove(subscripts.begin(), subscripts.end(), ' ');
-  subscripts.erase(end_pos, subscripts.end());
-
-  // Ensure all characters are valid
-  for (const char& c : subscripts) {
-    if (c == '.' || c == ',' || c == '-' || c == '>') {
-      continue;
-    }
-    CHECK(einsum_symbols_set.test(c)) << "Character " << c << " is not a valid symbol.";
-  }
-
-  // Check for proper "->"
-  if (subscripts.find('-') != std::string::npos || subscripts.find('>') != std::string::npos) {
-    bool invalid = (std::count(subscripts.begin(), subscripts.end(), '-') > 1 ||
-                    std::count(subscripts.begin(), subscripts.end(), '>') > 1);
-    CHECK(!invalid && CountSubstring(subscripts, "->") == 1)
-        << "Subscripts can only contain one '->'.";
-  }
-
-  // Parse ellipses
-  if (subscripts.find('.') != std::string::npos) {
-    std::string used = subscripts;
-    used.erase(
-        std::remove_if(used.begin(), used.end(),
-                       [](const char& c) { return c == '.' || c == ',' || c == '-' || c == '>'; }),
-        used.end());
-
-    std::bitset<LABELRANGE> used_set = Str2Set(used);
-    std::string ellipse_inds = "";
-    for (const char& c : einsum_symbols) {
-      if (!used_set.test(static_cast<int>(c))) {
-        ellipse_inds.append(1, c);
-      }
-    }
-    int longest = 0;
-    std::string input_tmp, output_sub;
-    std::vector<std::string> split_subscripts;
-    bool out_sub;
-
-    if (subscripts.find("->") != std::string::npos) {
-      std::vector<std::string> tmp = Split(subscripts, "->");
-      input_tmp = tmp[0];
-      output_sub = tmp[1];
-      split_subscripts = Split(input_tmp, ",");
-      out_sub = true;
-    } else {
-      split_subscripts = Split(subscripts, ",");
-      out_sub = false;
-    }
-
-    size_t size_split_subscripts = split_subscripts.size();
-    subscripts = "";
-    for (size_t i = 0; i < size_split_subscripts; ++i) {
-      const std::string& sub = split_subscripts[i];
-      if (sub.find('.') != std::string::npos) {
-        CHECK_EQ(std::count(sub.begin(), sub.end(), '.'), 3) << "Invalid Ellipses";
-        CHECK_EQ(CountSubstring(sub, "..."), 1) << "Invalid Ellipses";
-
-        // Take into account numerical values
-        int ellipse_count = 0;
-        if (operands[i].size() == 0) {
-          ellipse_count = 0;
-        } else {
-          ellipse_count = std::max(operands[i].size(), static_cast<size_t>(1));
-          ellipse_count -= sub.length() - 3;
-        }
-
-        if (ellipse_count > longest) {
-          longest = ellipse_count;
-        }
-
-        CHECK_GE(ellipse_count, 0) << "Ellipses lengths do not match.";
-        if (ellipse_count == 0) {
-          split_subscripts[i].erase(sub.find("..."), 3);
-        } else {
-          std::string rep_inds = ellipse_inds.substr(ellipse_inds.length() - ellipse_count);
-          split_subscripts[i].replace(sub.find("..."), 3, rep_inds);
-        }
-      }
-      subscripts += split_subscripts[i];
-      if (i + 1 < size_split_subscripts) {
-        subscripts += ",";
-      }
-    }
-    std::string out_ellipse;
-    if (longest == 0) {
-      out_ellipse = "";
-    } else {
-      out_ellipse = ellipse_inds.substr(ellipse_inds.length() - longest);
-    }
-
-    if (out_sub) {
-      output_sub.replace(output_sub.find("..."), 3, out_ellipse);
-      subscripts += "->" + output_sub;
-    } else {
-      // Special care for outputless ellipses
-      std::bitset<LABELRANGE> out_ellipse_set = Str2Set(out_ellipse);
-      std::string tmp_subscripts = subscripts, output_subscript = "";
-      size_t len_tmp_subscripts = tmp_subscripts.length();
-      std::sort(tmp_subscripts.begin(), tmp_subscripts.end());
-      for (size_t i = 0; i < len_tmp_subscripts; ++i) {
-        const char& c = tmp_subscripts[i];
-        if (c == ',') {
-          continue;
-        }
-        CHECK(einsum_symbols_set.test(c)) << "Character " << c << " is not a valid symbol.";
-        if ((i == 0 || tmp_subscripts[i - 1] != c) &&
-            (i == len_tmp_subscripts - 1 || tmp_subscripts[i + 1] != c) &&
-            !out_ellipse_set.test(c)) {
-          output_subscript.append(1, c);
-        }
-      }
-      subscripts += "->" + out_ellipse + output_subscript;
-    }
-  }
-
-  // Build output string if does not exist
-  std::tuple<std::string, std::string> ret;
-  if (subscripts.find("->") != std::string::npos) {
-    std::vector<std::string> tmp(2);
-    tmp = Split(subscripts, "->");
-    ret = std::make_tuple(tmp[0], tmp[1]);
-  } else {
-    std::string first = subscripts;
-    std::string second = "";
-    // Build output subscripts
-    std::string tmp_subscripts = subscripts;
-    size_t len_tmp_subscripts = tmp_subscripts.length();
-    std::sort(tmp_subscripts.begin(), tmp_subscripts.end());
-    for (size_t i = 0; i < len_tmp_subscripts; ++i) {
-      const char& c = tmp_subscripts[i];
-      if (c == ',') {
-        continue;
-      }
-      CHECK(einsum_symbols_set.test(c)) << "Character " << c << " is not a valid symbol.";
-      if ((i == 0 || tmp_subscripts[i - 1] != c) &&
-          (i == len_tmp_subscripts - 1 || tmp_subscripts[i + 1] != c)) {
-        second.append(1, c);
-      }
-    }
-    ret = std::make_tuple(first, second);
-  }
-
-  // Make sure output subscripts are in the input
-  std::bitset<LABELRANGE> input_subscripts_set = Str2Set(std::get<0>(ret));
-  for (const char& c : std::get<1>(ret)) {
-    CHECK(input_subscripts_set.test(c))
-        << "Output character " << c << " did not appear in the input";
-  }
-
-  // Make sure number operands is equivalent to the number of terms
-  CHECK_EQ(std::count(std::get<0>(ret).begin(), std::get<0>(ret).end(), ',') + 1, operands.size())
-      << "Number of einsum subscripts must be equal to the "
-      << "number of operands.";
-
-  return ret;
-}
-
 /*!
  * \brief Compute the shape of the output.
  * \param subscripts input subscripts.
  * \param operands operand tensors.
  *
  * \return the shape of the output.
  */
-inline Array<PrimExpr> NumpyEinsumShape(const std::string subscripts,
-                                        const std::vector<Array<PrimExpr>>& operands) {
-  // Parsing
-  std::tuple<std::string, std::string> parsed_subscripts = ParseEinsumInput(subscripts, operands);
-
-  // Build a few useful list and sets
-  std::vector<std::string> input_list = Split(std::get<0>(parsed_subscripts), ",");
-  size_t isize = input_list.size();
-
-  // Get length of each unique dimension and ensure all dimensions are correct
-  int dimension_dict[LABELRANGE];
-  memset(dimension_dict, -1, sizeof(dimension_dict));
-  for (size_t i = 0; i < isize; ++i) {
-    const std::string& term = input_list[i];
-    const Array<PrimExpr>& sh = operands[i];
-    CHECK_EQ(sh.size(), term.length())
-        << "Einstein sum subscript " << input_list[i] << " does not contain the "
-        << "correct number of indices for operand " << i << ".";
-    size_t len_term = term.length();
-    for (size_t j = 0; j < len_term; ++j) {
-      int64_t dim = GetConstInt(sh[j]);
-      const char& c = term[j];
-
-      if (dimension_dict[static_cast<int>(c)] != -1) {
-        // For broadcasting cases we always want the largest dim size
-        if (dimension_dict[static_cast<int>(c)] == 1) {
-          dimension_dict[static_cast<int>(c)] = dim;
-        }
-        CHECK(dim == 1 || dim == dimension_dict[static_cast<int>(c)])
-            << "Size of label '" << c << "' for operand  " << i << " ("
-            << dimension_dict[static_cast<int>(c)] << ") does not match previous terms (" << dim
-            << ").";
-      } else {
-        dimension_dict[static_cast<int>(c)] = dim;
-      }
-    }
-  }
-
-  // Get oshape
-  const std::string& output_str = std::get<1>(parsed_subscripts);
-  size_t odim = output_str.size();
-  Array<PrimExpr> oshape(odim, -1);
-  for (size_t i = 0; i < odim; ++i) {
-    oshape.Set(i, dimension_dict[static_cast<int>(output_str[i])]);
-  }
-  // Neglecting oshape assign check temporally
-  return oshape;
-}
+Array<PrimExpr> NumpyEinsumShape(const std::string subscripts,

Review Comment:
   Just curious: what does the “Numpy” prefix stand for?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] vinx13 merged pull request #12913: [TOPI] Implement Einsum with reduction axes

Posted by GitBox <gi...@apache.org>.
vinx13 merged PR #12913:
URL: https://github.com/apache/tvm/pull/12913


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] vinx13 commented on a diff in pull request #12913: [TOPI] Implement Einsum with reduction axes

Posted by GitBox <gi...@apache.org>.
vinx13 commented on code in PR #12913:
URL: https://github.com/apache/tvm/pull/12913#discussion_r981687802


##########
src/topi/transform.cc:
##########
@@ -173,10 +173,6 @@ TVM_REGISTER_GLOBAL("topi.tensordot").set_body([](TVMArgs args, TVMRetValue* rv)
   }
 });
 
-TVM_REGISTER_GLOBAL("topi.einsum").set_body([](TVMArgs args, TVMRetValue* rv) {
-  *rv = einsum(args[0], args[1]);
-});
-

Review Comment:
   It has been moved to the `einsum.cc` file.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] MasterJH5574 commented on a diff in pull request #12913: [TOPI] Implement Einsum with reduction axes

Posted by GitBox <gi...@apache.org>.
MasterJH5574 commented on code in PR #12913:
URL: https://github.com/apache/tvm/pull/12913#discussion_r981696199


##########
include/tvm/topi/einsum.h:
##########
@@ -49,623 +49,15 @@ namespace topi {
 using namespace tvm::te;
 using namespace topi::detail;
 
-/*!
- * \brief Compute the stride of the given shape.
- *
- * \param shape for the operation.
- *
- * \return the stride of the shape.
- */
-inline Array<PrimExpr> GetStride(const Array<PrimExpr> shape) {
-  size_t ndim = shape.size();
-  int prod = 1;
-  Array<PrimExpr> stride = Array<PrimExpr>(ndim, -1);
-  for (int i = ndim - 1; i >= 0; i--) {
-    stride.Set(i, if_then_else(shape[i] > 1, prod, 0));
-    prod = prod * GetConstInt(shape[i]);
-  }
-  return stride;
-}
-
-/*!
- * \brief Pad the shape with 1.
- *
- * \param shape the input shape to be padded
- * \param odim the padding size of the objective shape.
- *
- * \return the padded shape.
- */
-inline Array<PrimExpr> Pad(const Array<PrimExpr> shape, int odim) {
-  int ndim = shape.size();
-  CHECK_GE(odim, ndim);
-  Array<PrimExpr> ret(static_cast<size_t>(odim), 1);
-  for (int idim = 0; idim < ndim; ++idim) {
-    ret.Set(idim, shape[idim]);
-  }
-  return ret;
-}
-
-/*!
- * \brief Parse the subscripts for one operand into an output of 'ndim' labels.
- *
- * \param subscripts the subscripts for to be parsed.
- * \param length subscripts[0: length] represents the current operand.
- * \param ndim the ndim of current operand.
- * \param iop the index of the operand.
- * \param op_labels the parsing result.
- *        For Example:
- *           subscripts="abbcbc",  ndim=6 -> op_labels=[97, 98, -1, 99, -3, -2].
- *           subscripts="ab...bc", ndim=6 -> op_labels=[97, 98, 0, 0, -3, 99].
- * \param label_counts Count the number the label appears.
- * \param min_label Save the minimal label according to ASCII.
- * \param max_label Save the maximal label according to ASCII.
- *
- * \return 0.
- */
-inline int ParseOperandSubscripts(const char* subscripts, int length, int ndim, int iop,
-                                  char* op_labels, char* label_counts, int* min_label,
-                                  int* max_label) {
-  int i;
-  int idim = 0;
-  int ellipsis = -1;
-
-  /* Process all labels for this operand */
-  for (i = 0; i < length; ++i) {
-    int label = subscripts[i];
-
-    /* A proper label for an axis. */
-    if (label > 0 && isalpha(label)) {
-      /* Check we don't exceed the operator dimensions. */
-      CHECK(idim < ndim) << "einstein sum subscripts string contains "
-                         << "too many subscripts for operand " << iop;
-
-      op_labels[idim++] = label;
-      if (label < *min_label) {
-        *min_label = label;
-      }
-      if (label > *max_label) {
-        *max_label = label;
-      }
-      label_counts[label]++;
-    } else if (label == '.') {
-      /* The beginning of the ellipsis. */
-      /* Check it's a proper ellipsis. */
-      CHECK(
-          !(ellipsis != -1 || i + 2 >= length || subscripts[++i] != '.' || subscripts[++i] != '.'))
-          << "einstein sum subscripts string contains a "
-          << "'.' that is not part of an ellipsis ('...') "
-          << "in operand " << iop;
-
-      ellipsis = idim;
-    } else {
-      CHECK(label == ' ') << "invalid subscript '" << static_cast<char>(label)
-                          << "' in einstein sum "
-                          << "subscripts string, subscripts must "
-                          << "be letters";
-    }
-  }
-
-  /* No ellipsis found, labels must match dimensions exactly. */
-  if (ellipsis == -1) {
-    CHECK(idim == ndim) << "operand has more dimensions than subscripts "
-                        << "given in einstein sum, but no '...' ellipsis "
-                        << "provided to broadcast the extra dimensions.";
-  } else if (idim < ndim) {
-    /* Ellipsis found, may have to add broadcast dimensions. */
-    /* Move labels after ellipsis to the end. */
-    for (i = 0; i < idim - ellipsis; ++i) {
-      op_labels[ndim - i - 1] = op_labels[idim - i - 1];
-    }
-    /* Set all broadcast dimensions to zero. */
-    for (i = 0; i < ndim - idim; ++i) {
-      op_labels[ellipsis + i] = 0;
-    }
-  }
-
-  /*
-   * Find any labels duplicated for this operand, and turn them
-   * into negative offsets to the axis to merge with.
-   *
-   * In C, the char type may be signed or unsigned, but with
-   * twos complement arithmetic the char is ok either way here, and
-   * later where it matters the char is cast to a signed char.
-   */
-  for (idim = 0; idim < ndim - 1; ++idim) {
-    int label = op_labels[idim];
-    /* If it is a proper label, find any duplicates of it. */
-    if (label > 0) {
-      /* Search for the next matching label. */
-      char* next = reinterpret_cast<char*>(memchr(op_labels + idim + 1, label, ndim - idim - 1));
-
-      while (next != nullptr) {
-        /* The offset from next to op_labels[idim] (negative). */
-        *next = static_cast<char>((op_labels + idim) - next);
-        /* Search for the next matching label. */
-        next = reinterpret_cast<char*>(memchr(next + 1, label, op_labels + ndim - 1 - next));
-      }
-    }
-  }
-  return 0;
-}
-
-/*!
- * \brief Parse the subscripts for the output into an output that includes 'ndim_broadcast'
- *        unlabeled dimensions.
- *
- * \param subscripts the subscripts for to be parsed.
- * \param length subscripts[0: length] represents the output operand.
- * \param ndim_broadcast the broadcast dimension number.
- * \param label_counts Count the number the label appears.
- * \param out_labels similar to the op_labels in ParseOperandSubscripts, for each
- *        dimension, the ASCII code of the corresponding label. zero for the broadcasting dim.
- *
- * \return the total number of output dimensions or -1 if there is an error.
- */
-inline int ParseOutputSubscripts(const char* subscripts, int length, int ndim_broadcast,
-                                 const char* label_counts, char* out_labels) {
-  int i, bdim;
-  int ndim = 0;
-  int ellipsis = 0;
-
-  /* Process all the output labels. */
-  for (i = 0; i < length; ++i) {
-    int label = subscripts[i];
-
-    /* A proper label for an axis. */
-    if (label > 0 && isalpha(label)) {
-      /* Check that it doesn't occur again. */
-      CHECK(memchr(subscripts + i + 1, label, length - i - 1) == nullptr)
-          << "einstein sum subscripts string includes "
-          << "output subscript '" << static_cast<char>(label) << "' multiple times";
-
-      /* Check that it was used in the inputs. */
-      CHECK(label_counts[label] != 0)
-          << "einstein sum subscripts string included "
-          << "output subscript '" << static_cast<char>(label) << "' which never appeared "
-          << "in an input";
-
-      /* Check that there is room in out_labels for this label. */
-      CHECK(ndim < NPY_MAXDIMS) << "einstein sum subscripts string contains "
-                                << "too many subscripts in the output";
-
-      out_labels[ndim++] = label;
-    } else if (label == '.') {
-      /* The beginning of the ellipsis. */
-      /* Check it is a proper ellipsis. */
-      CHECK(!(ellipsis || i + 2 >= length || subscripts[++i] != '.' || subscripts[++i] != '.'))
-          << "einstein sum subscripts string "
-          << "contains a '.' that is not part of "
-          << "an ellipsis ('...') in the output";
-
-      /* Check there is room in out_labels for broadcast dims. */
-      CHECK(ndim + ndim_broadcast <= NPY_MAXDIMS) << "einstein sum subscripts string contains "
-                                                  << "too many subscripts in the output";
-
-      ellipsis = 1;
-      for (bdim = 0; bdim < ndim_broadcast; ++bdim) {
-        out_labels[ndim++] = 0;
-      }
-    } else {
-      CHECK(label == ' ') << "invalid subscript '" << static_cast<char>(label)
-                          << "' in einstein sum "
-                          << "subscripts string, subscripts must "
-                          << "be letters";
-    }
-  }
-
-  /* If no ellipsis was found there should be no broadcast dimensions. */
-  CHECK(!(!ellipsis && ndim_broadcast > 0)) << "output has more dimensions than subscripts "
-                                            << "given in einstein sum, but no '...' ellipsis "
-                                            << "provided to broadcast the extra dimensions.";
-
-  return ndim;
-}
-
-/*!
- * \brief If any dimensions are combined, create a view that combines them.
- *        Shows in newshape and newstride.
- *
- * \param op the operand tensor.
- * \param iop the index of the operand.
- * \param labels the op_labels fot the operand. Like [97, 98, -2] for "aba".
- * \param newshape The combined shape.
- * \param newstride The combined stride.
- *
- * For example:
- *  "aba -> ab",              shape = [2,3,2] stride = [6,2,1]
- *  op_labels = [97, 98, -2], newshape = [2,3], newstride = [7,2]
- */
-inline void GetCombinedDimsView(const Tensor& op, int iop, char* labels, Array<PrimExpr>* newshape,
-                                Array<PrimExpr>* newstride) {
-  int idim, ndim, icombine, combineoffset;
-  int icombinemap[NPY_MAXDIMS];
-  int newdim;
-
-  Array<PrimExpr> shape = op->shape;
-  Array<PrimExpr> stride = GetStride(shape);
-  ndim = op.ndim();
-  newdim = newshape->size();
-
-  /* Initialize the dimensions and strides to zero */
-  for (idim = 0; idim < newdim; ++idim) {
-    newshape->Set(idim, 0);
-    newstride->Set(idim, 0);
-  }
-
-  /* Copy the dimensions and strides, except when collapsing */
-  icombine = 0;
-  for (idim = 0; idim < ndim; ++idim) {
-    /*
-     * The char type may be either signed or unsigned, we
-     * need it to be signed here.
-     */
-    int label = (signed char)labels[idim];
-    /* If this label says to merge axes, get the actual label */
-    if (label < 0) {
-      combineoffset = label;
-      label = labels[idim + label];
-    } else {
-      combineoffset = 0;
-      if (icombine != idim) {
-        labels[icombine] = labels[idim];
-      }
-      icombinemap[idim] = icombine;
-    }
-    /* If the label is 0, it's an unlabeled broadcast dimension */
-    if (label == 0) {
-      newshape->Set(icombine, shape[idim]);
-      newstride->Set(icombine, stride[idim]);
-    } else {
-      /* Update the combined axis dimensions and strides */
-      int i = icombinemap[idim + combineoffset];
-      CHECK(!((combineoffset < 0) &&
-              GetConstInt((*newshape)[i] != 0 && (*newshape)[i] != shape[idim])))
-          << "dimensions in operand " << iop << " for collapsing index '" << label
-          << "' don't match (" << GetConstInt((*newshape)[i]) << " != " << shape[idim] << ")";
-      newshape->Set(i, shape[idim]);
-      newstride->Set(i, (*newstride)[i] + stride[idim]);
-    }
-
-    /* If the label didn't say to combine axes, increment dest i */
-    if (combineoffset == 0) {
-      icombine++;
-    }
-  }
-}
-
-/*!
- * \brief Prepare the operand axes to match each stride or shape pair.
- *
- * \param ndim the ndim of the operand tensor.
- * \param iop the index of the operand.
- * \param labels the op_labels fot the operand. [97, 98, -1, 99, -3, -2] for "abbcbc".
- * \param axes The matched axes to be calculated.
- * \param ndim_iter the dimension of iterating. Subscripts "ab, bc -> ac" ndim_iter = 3.
- * \param iter_labels output_labels with the iterating label. ['a', 'c', 'b'] for the case above.
- */
-inline static int PrepareOpAxes(int ndim, int iop, char* labels, int* axes, int ndim_iter,
-                                char* iter_labels) {
-  int i, label, ibroadcast;
-
-  ibroadcast = ndim - 1;
-  for (i = ndim_iter - 1; i >= 0; --i) {
-    label = iter_labels[i];
-    /*
-     * If it's an unlabeled broadcast dimension, choose
-     * the next broadcast dimension from the operand.
-     */
-    if (label == 0) {
-      while (ibroadcast >= 0 && labels[ibroadcast] != 0) {
-        --ibroadcast;
-      }
-      /*
-       * If we used up all the operand broadcast dimensions,
-       * extend it with a "newaxis"
-       */
-      if (ibroadcast < 0) {
-        axes[i] = -1;
-      } else {
-        /* Otherwise map to the broadcast axis */
-        axes[i] = ibroadcast;
-        --ibroadcast;
-      }
-    } else {
-      /* It's a labeled dimension, find the matching one */
-      char* match = reinterpret_cast<char*>(memchr(labels, label, ndim));
-      /* If the op doesn't have the label, broadcast it */
-      if (match == nullptr) {
-        axes[i] = -1;
-      } else {
-        /* Otherwise use it */
-        axes[i] = match - labels;
-      }
-    }
-  }
-  return 0;
-}
-
-/*!
- * \brief Count SubString.
- * \param str the object string
- * \param sub the pattern string
- *
- * \return number of substring
- */
-inline int CountSubstring(const std::string& str, const std::string& sub) {
-  int count = 0;
-  std::string::size_type pos = 0;
-  while ((pos = str.find(sub, pos)) != std::string::npos) {
-    ++count;
-    pos += sub.length();
-  }
-  return count;
-}
-
-/*!
- * \brief Transfer string to.
- * \param str input string.
- *
- * \return bitset.
- */
-inline std::bitset<LABELRANGE> Str2Set(const std::string& str) {
-  std::bitset<LABELRANGE> ret;
-  for (const char& c : str) {
-    ret.set(static_cast<int>(c));
-  }
-  return ret;
-}
-
-/*!
- * \brief Split str according to substring.
- * \param str input string.
- * \param sub the split pattern string.
- *
- * \return vector contains the splited substring.
- */
-inline std::vector<std::string> Split(const std::string& str, const std::string& sub) {
-  std::string::size_type pos = 0;
-  std::string::size_type start = 0;
-  std::vector<std::string> ret;
-  while ((pos = str.find(sub, start)) != std::string::npos) {
-    ret.push_back(str.substr(start, pos - start));
-    start = pos + sub.length();
-  }
-  ret.push_back(str.substr(start));
-  return ret;
-}
-
-/*!
- * \brief Parse the input subscripts into a vector of strings.
- * \param subscripts input subscripts.
- * \param operands operand tensors.
- *
- * \return vector of strings, vector[0] represents the input part, vector[1] represents the output.
- * if no output, the vector[1] is NULL.
- * "ab, bc -> ac" => ["ab,bc", "ac"]
- */
-inline std::tuple<std::string, std::string> ParseEinsumInput(
-    std::string subscripts, const std::vector<Array<PrimExpr>>& operands) {
-  const std::string einsum_symbols = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
-  std::bitset<LABELRANGE> einsum_symbols_set;
-  for (const char& c : einsum_symbols) {
-    einsum_symbols_set.set(c);
-  }
-
-  CHECK_NE(operands.size(), 0U) << "No input operands";
-
-  auto end_pos = std::remove(subscripts.begin(), subscripts.end(), ' ');
-  subscripts.erase(end_pos, subscripts.end());
-
-  // Ensure all characters are valid
-  for (const char& c : subscripts) {
-    if (c == '.' || c == ',' || c == '-' || c == '>') {
-      continue;
-    }
-    CHECK(einsum_symbols_set.test(c)) << "Character " << c << " is not a valid symbol.";
-  }
-
-  // Check for proper "->"
-  if (subscripts.find('-') != std::string::npos || subscripts.find('>') != std::string::npos) {
-    bool invalid = (std::count(subscripts.begin(), subscripts.end(), '-') > 1 ||
-                    std::count(subscripts.begin(), subscripts.end(), '>') > 1);
-    CHECK(!invalid && CountSubstring(subscripts, "->") == 1)
-        << "Subscripts can only contain one '->'.";
-  }
-
-  // Parse ellipses
-  if (subscripts.find('.') != std::string::npos) {
-    std::string used = subscripts;
-    used.erase(
-        std::remove_if(used.begin(), used.end(),
-                       [](const char& c) { return c == '.' || c == ',' || c == '-' || c == '>'; }),
-        used.end());
-
-    std::bitset<LABELRANGE> used_set = Str2Set(used);
-    std::string ellipse_inds = "";
-    for (const char& c : einsum_symbols) {
-      if (!used_set.test(static_cast<int>(c))) {
-        ellipse_inds.append(1, c);
-      }
-    }
-    int longest = 0;
-    std::string input_tmp, output_sub;
-    std::vector<std::string> split_subscripts;
-    bool out_sub;
-
-    if (subscripts.find("->") != std::string::npos) {
-      std::vector<std::string> tmp = Split(subscripts, "->");
-      input_tmp = tmp[0];
-      output_sub = tmp[1];
-      split_subscripts = Split(input_tmp, ",");
-      out_sub = true;
-    } else {
-      split_subscripts = Split(subscripts, ",");
-      out_sub = false;
-    }
-
-    size_t size_split_subscripts = split_subscripts.size();
-    subscripts = "";
-    for (size_t i = 0; i < size_split_subscripts; ++i) {
-      const std::string& sub = split_subscripts[i];
-      if (sub.find('.') != std::string::npos) {
-        CHECK_EQ(std::count(sub.begin(), sub.end(), '.'), 3) << "Invalid Ellipses";
-        CHECK_EQ(CountSubstring(sub, "..."), 1) << "Invalid Ellipses";
-
-        // Take into account numerical values
-        int ellipse_count = 0;
-        if (operands[i].size() == 0) {
-          ellipse_count = 0;
-        } else {
-          ellipse_count = std::max(operands[i].size(), static_cast<size_t>(1));
-          ellipse_count -= sub.length() - 3;
-        }
-
-        if (ellipse_count > longest) {
-          longest = ellipse_count;
-        }
-
-        CHECK_GE(ellipse_count, 0) << "Ellipses lengths do not match.";
-        if (ellipse_count == 0) {
-          split_subscripts[i].erase(sub.find("..."), 3);
-        } else {
-          std::string rep_inds = ellipse_inds.substr(ellipse_inds.length() - ellipse_count);
-          split_subscripts[i].replace(sub.find("..."), 3, rep_inds);
-        }
-      }
-      subscripts += split_subscripts[i];
-      if (i + 1 < size_split_subscripts) {
-        subscripts += ",";
-      }
-    }
-    std::string out_ellipse;
-    if (longest == 0) {
-      out_ellipse = "";
-    } else {
-      out_ellipse = ellipse_inds.substr(ellipse_inds.length() - longest);
-    }
-
-    if (out_sub) {
-      output_sub.replace(output_sub.find("..."), 3, out_ellipse);
-      subscripts += "->" + output_sub;
-    } else {
-      // Special care for outputless ellipses
-      std::bitset<LABELRANGE> out_ellipse_set = Str2Set(out_ellipse);
-      std::string tmp_subscripts = subscripts, output_subscript = "";
-      size_t len_tmp_subscripts = tmp_subscripts.length();
-      std::sort(tmp_subscripts.begin(), tmp_subscripts.end());
-      for (size_t i = 0; i < len_tmp_subscripts; ++i) {
-        const char& c = tmp_subscripts[i];
-        if (c == ',') {
-          continue;
-        }
-        CHECK(einsum_symbols_set.test(c)) << "Character " << c << " is not a valid symbol.";
-        if ((i == 0 || tmp_subscripts[i - 1] != c) &&
-            (i == len_tmp_subscripts - 1 || tmp_subscripts[i + 1] != c) &&
-            !out_ellipse_set.test(c)) {
-          output_subscript.append(1, c);
-        }
-      }
-      subscripts += "->" + out_ellipse + output_subscript;
-    }
-  }
-
-  // Build output string if does not exist
-  std::tuple<std::string, std::string> ret;
-  if (subscripts.find("->") != std::string::npos) {
-    std::vector<std::string> tmp(2);
-    tmp = Split(subscripts, "->");
-    ret = std::make_tuple(tmp[0], tmp[1]);
-  } else {
-    std::string first = subscripts;
-    std::string second = "";
-    // Build output subscripts
-    std::string tmp_subscripts = subscripts;
-    size_t len_tmp_subscripts = tmp_subscripts.length();
-    std::sort(tmp_subscripts.begin(), tmp_subscripts.end());
-    for (size_t i = 0; i < len_tmp_subscripts; ++i) {
-      const char& c = tmp_subscripts[i];
-      if (c == ',') {
-        continue;
-      }
-      CHECK(einsum_symbols_set.test(c)) << "Character " << c << " is not a valid symbol.";
-      if ((i == 0 || tmp_subscripts[i - 1] != c) &&
-          (i == len_tmp_subscripts - 1 || tmp_subscripts[i + 1] != c)) {
-        second.append(1, c);
-      }
-    }
-    ret = std::make_tuple(first, second);
-  }
-
-  // Make sure output subscripts are in the input
-  std::bitset<LABELRANGE> input_subscripts_set = Str2Set(std::get<0>(ret));
-  for (const char& c : std::get<1>(ret)) {
-    CHECK(input_subscripts_set.test(c))
-        << "Output character " << c << " did not appear in the input";
-  }
-
-  // Make sure number operands is equivalent to the number of terms
-  CHECK_EQ(std::count(std::get<0>(ret).begin(), std::get<0>(ret).end(), ',') + 1, operands.size())
-      << "Number of einsum subscripts must be equal to the "
-      << "number of operands.";
-
-  return ret;
-}
-
 /*!
  * \brief Compute the shape of the output.
  * \param subscripts input subscripts.
  * \param operands operand tensors.
  *
  * \return the shape of the output.
  */
-inline Array<PrimExpr> NumpyEinsumShape(const std::string subscripts,
-                                        const std::vector<Array<PrimExpr>>& operands) {
-  // Parsing
-  std::tuple<std::string, std::string> parsed_subscripts = ParseEinsumInput(subscripts, operands);
-
-  // Build a few useful list and sets
-  std::vector<std::string> input_list = Split(std::get<0>(parsed_subscripts), ",");
-  size_t isize = input_list.size();
-
-  // Get length of each unique dimension and ensure all dimensions are correct
-  int dimension_dict[LABELRANGE];
-  memset(dimension_dict, -1, sizeof(dimension_dict));
-  for (size_t i = 0; i < isize; ++i) {
-    const std::string& term = input_list[i];
-    const Array<PrimExpr>& sh = operands[i];
-    CHECK_EQ(sh.size(), term.length())
-        << "Einstein sum subscript " << input_list[i] << " does not contain the "
-        << "correct number of indices for operand " << i << ".";
-    size_t len_term = term.length();
-    for (size_t j = 0; j < len_term; ++j) {
-      int64_t dim = GetConstInt(sh[j]);
-      const char& c = term[j];
-
-      if (dimension_dict[static_cast<int>(c)] != -1) {
-        // For broadcasting cases we always want the largest dim size
-        if (dimension_dict[static_cast<int>(c)] == 1) {
-          dimension_dict[static_cast<int>(c)] = dim;
-        }
-        CHECK(dim == 1 || dim == dimension_dict[static_cast<int>(c)])
-            << "Size of label '" << c << "' for operand  " << i << " ("
-            << dimension_dict[static_cast<int>(c)] << ") does not match previous terms (" << dim
-            << ").";
-      } else {
-        dimension_dict[static_cast<int>(c)] = dim;
-      }
-    }
-  }
-
-  // Get oshape
-  const std::string& output_str = std::get<1>(parsed_subscripts);
-  size_t odim = output_str.size();
-  Array<PrimExpr> oshape(odim, -1);
-  for (size_t i = 0; i < odim; ++i) {
-    oshape.Set(i, dimension_dict[static_cast<int>(output_str[i])]);
-  }
-  // Neglecting oshape assign check temporally
-  return oshape;
-}
+Array<PrimExpr> NumpyEinsumShape(const std::string subscripts,

Review Comment:
   Sounds good. The function only does pure shape inference... Just a random thought: what about “InferEinsumShape”?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org