You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/02/02 03:08:48 UTC

[GitHub] [tvm] hanke580 commented on a change in pull request #6370: [TOPI] Add einsum operator

hanke580 commented on a change in pull request #6370:
URL: https://github.com/apache/tvm/pull/6370#discussion_r568292029



##########
File path: include/tvm/topi/einsum.h
##########
@@ -0,0 +1,937 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file topi/einsum.h
+ * \brief Einstein summation op
+ */
+#ifndef TVM_TOPI_EINSUM_H_
+#define TVM_TOPI_EINSUM_H_
+
+#include <tvm/te/operation.h>
+#include <tvm/tir/data_layout.h>
+#include <tvm/topi/detail/constant_utils.h>
+#include <tvm/topi/detail/ravel_unravel.h>
+#include <tvm/topi/detail/tensor_utils.h>
+#include <tvm/topi/tags.h>
+
+#include <algorithm>
+#include <bitset>
+#include <cctype>
+#include <cstring>
+#include <iterator>
+#include <string>
+#include <tuple>
+#include <unordered_set>
+#include <vector>
+
+namespace tvm {
+namespace topi {
+
+using namespace tvm::te;
+using namespace topi::detail;
+
+/*!
+ * \brief Compute the row-major strides of a shape.
+ *
+ * The stride of an axis is the product of the extents of all axes after it. The stride
+ * expression of an axis evaluates to 0 when the extent of that axis is not greater than 1.
+ *
+ * \param shape The shape to compute strides for. Every extent is read with GetConstInt.
+ *
+ * \return The stride of each axis, in the same order as the axes of the shape.
+ */
+inline Array<PrimExpr> GetStride(const Array<PrimExpr> shape) {
+  const size_t rank = shape.size();
+  Array<PrimExpr> strides(rank, -1);
+  // Number of elements spanned by the axes visited so far (the innermost ones).
+  int running_extent = 1;
+  // Walk from the last axis to the first.
+  int axis = static_cast<int>(rank);
+  while (axis-- > 0) {
+    strides.Set(axis, if_then_else(shape[axis] > 1, running_extent, 0));
+    running_extent *= GetConstInt(shape[axis]);
+  }
+  return strides;
+}
+
+/*!
+ * \brief Extend a shape to a larger rank by appending extents of 1.
+ *
+ * \param shape The input shape to be padded.
+ * \param odim The rank of the padded shape. Must not be smaller than the rank of the input.
+ *
+ * \return The padded shape: the input extents followed by ones.
+ */
+inline Array<PrimExpr> Pad(const Array<PrimExpr> shape, int odim) {
+  const int ndim = static_cast<int>(shape.size());
+  CHECK_GE(odim, ndim);
+  // Start from an all-ones shape of the target rank, then overwrite the leading entries.
+  Array<PrimExpr> padded(static_cast<size_t>(odim), 1);
+  int axis = 0;
+  while (axis < ndim) {
+    padded.Set(axis, shape[axis]);
+    ++axis;
+  }
+  return padded;
+}
+
+/*!
+ * \brief Parse the subscripts for one operand into an output of 'ndim' labels.
+ *
+ * Spaces in the subscripts are skipped. Every other character must be a letter or part of
+ * a single "..." ellipsis; violations are reported through CHECK.
+ *
+ * \param subscripts the subscripts to be parsed.
+ * \param length subscripts[0: length] represents the current operand.
+ * \param ndim the ndim of current operand.
+ * \param iop the index of the operand, used only in error messages.
+ * \param op_labels the parsing result, one entry per axis: the character code of a label at
+ *        its first occurrence, 0 for an axis covered by the ellipsis, and a negative offset
+ *        to the first occurrence for a repeated label.
+ *        For Example:
+ *           subscripts="abbcbc",  ndim=6 -> op_labels=[97, 98, -1, 99, -3, -2].
+ *           subscripts="ab...bc", ndim=6 -> op_labels=[97, 98, 0, 0, -3, 99].
+ * \param label_counts Per-label occurrence counters, indexed by the character code of the
+ *        label and incremented for every label parsed here.
+ * \param min_label Lowered to the smallest label character code seen, if it is smaller.
+ * \param max_label Raised to the largest label character code seen, if it is larger.
+ *
+ * \return 0.
+ */
+inline int ParseOperandSubscripts(const char* subscripts, int length, int ndim, int iop,
+                                  char* op_labels, char* label_counts, int* min_label,
+                                  int* max_label) {
+  int i;
+  int idim = 0;       // number of labeled axes written to op_labels so far
+  int ellipsis = -1;  // value of idim at the point "..." was seen, -1 while none was seen
+
+  /* Process all labels for this operand */
+  for (i = 0; i < length; ++i) {
+    int label = subscripts[i];
+
+    /* A proper label for an axis. */
+    if (label > 0 && isalpha(label)) {
+      /* Check we don't exceed the operator dimensions. */
+      CHECK(idim < ndim) << "einstein sum subscripts string contains "
+                         << "too many subscripts for operand " << iop;
+
+      op_labels[idim++] = label;
+      if (label < *min_label) {
+        *min_label = label;
+      }
+      if (label > *max_label) {
+        *max_label = label;
+      }
+      label_counts[label]++;
+    } else if (label == '.') {
+      /* The beginning of the ellipsis. */
+      /* Check it's a proper ellipsis. The two `++i` inside the condition also step past
+       * the remaining dots, so the loop resumes after the "...". */
+      CHECK(
+          !(ellipsis != -1 || i + 2 >= length || subscripts[++i] != '.' || subscripts[++i] != '.'))
+          << "einstein sum subscripts string contains a "
+          << "'.' that is not part of an ellipsis ('...') "
+          << "in operand " << iop;
+
+      ellipsis = idim;
+    } else {
+      CHECK(label == ' ') << "invalid subscript '" << static_cast<char>(label)
+                          << "' in einstein sum "
+                          << "subscripts string, subscripts must "
+                          << "be letters";
+    }
+  }
+
+  /* No ellipsis found, labels must match dimensions exactly. */
+  if (ellipsis == -1) {
+    CHECK(idim == ndim) << "operand has more dimensions than subscripts "
+                        << "given in einstein sum, but no '...' ellipsis "
+                        << "provided to broadcast the extra dimensions.";
+  } else if (idim < ndim) {
+    /* Ellipsis found, may have to add broadcast dimensions. */
+    /* Move labels after ellipsis to the end. Copying starts from the last label so that
+     * the source and destination ranges may overlap. */
+    for (i = 0; i < idim - ellipsis; ++i) {
+      op_labels[ndim - i - 1] = op_labels[idim - i - 1];
+    }
+    /* Set all broadcast dimensions to zero. */
+    for (i = 0; i < ndim - idim; ++i) {
+      op_labels[ellipsis + i] = 0;
+    }
+  }
+
+  /*
+   * Find any labels duplicated for this operand, and turn them
+   * into negative offsets to the axis to merge with.
+   *
+   * In C, the char type may be signed or unsigned, but with
+   * twos complement arithmetic the char is ok either way here, and
+   * later where it matters the char is cast to a signed char.
+   */
+  for (idim = 0; idim < ndim - 1; ++idim) {
+    int label = op_labels[idim];
+    /* If it is a proper label, find any duplicates of it. Entries already rewritten to
+     * offsets, and broadcast entries, are not positive and are skipped. */
+    if (label > 0) {
+      /* Search for the next matching label. */
+      char* next = reinterpret_cast<char*>(memchr(op_labels + idim + 1, label, ndim - idim - 1));
+
+      while (next != nullptr) {
+        /* The offset from next to op_labels[idim] (negative). */
+        *next = static_cast<char>((op_labels + idim) - next);
+        /* Search for the next matching label. */
+        next = reinterpret_cast<char*>(memchr(next + 1, label, op_labels + ndim - 1 - next));
+      }
+    }
+  }
+  return 0;
+}
+
+/*!
+ * \brief Parse the subscripts for the output into an output that includes 'ndim_broadcast'
+ *        unlabeled dimensions.
+ *
+ * Spaces in the subscripts are skipped. Invalid subscripts are reported through CHECK.
+ *
+ * \param subscripts the subscripts to be parsed.
+ * \param length subscripts[0: length] represents the output operand.
+ * \param ndim_broadcast the broadcast dimension number.
+ * \param label_counts Count the number of times each label appears, indexed by the
+ *        character code of the label. A label with a zero count is rejected.
+ * \param out_labels similar to the op_labels in ParseOperandSubscripts, for each
+ *        dimension, the ASCII code of the corresponding label. zero for the broadcasting dim.
+ *        At most 16 entries are written.
+ *
+ * \return the total number of output dimensions.
+ */
+inline int ParseOutputSubscripts(const char* subscripts, int length, int ndim_broadcast,
+                                 const char* label_counts, char* out_labels) {
+  int i, bdim;
+  int ndim = 0;      // number of entries written to out_labels so far
+  int ellipsis = 0;  // set to 1 once "..." has been consumed
+
+  /* Process all the output labels. */
+  for (i = 0; i < length; ++i) {
+    int label = subscripts[i];
+
+    /* A proper label for an axis. */
+    if (label > 0 && isalpha(label)) {
+      /* Check that it doesn't occur again. */
+      CHECK(memchr(subscripts + i + 1, label, length - i - 1) == nullptr)
+          << "einstein sum subscripts string includes "
+          << "output subscript '" << static_cast<char>(label) << "' multiple times";
+
+      /* Check that it was used in the inputs. */
+      CHECK(label_counts[label] != 0)
+          << "einstein sum subscripts string included "
+          << "output subscript '" << static_cast<char>(label) << "' which never appeared "
+          << "in an input";
+
+      /* Check that there is room in out_labels for this label. */
+      CHECK(ndim < 16) << "einstein sum subscripts string contains "
+                       << "too many subscripts in the output";
+
+      out_labels[ndim++] = label;
+    } else if (label == '.') {
+      /* The beginning of the ellipsis. */
+      /* Check it is a proper ellipsis. The two `++i` inside the condition also step past
+       * the remaining dots, so the loop resumes after the "...". */
+      CHECK(!(ellipsis || i + 2 >= length || subscripts[++i] != '.' || subscripts[++i] != '.'))
+          << "einstein sum subscripts string "
+          << "contains a '.' that is not part of "
+          << "an ellipsis ('...') in the output";
+
+      /* Check there is room in out_labels for broadcast dims. */
+      CHECK(ndim + ndim_broadcast <= 16) << "einstein sum subscripts string contains "
+                                         << "too many subscripts in the output";
+
+      ellipsis = 1;
+      for (bdim = 0; bdim < ndim_broadcast; ++bdim) {
+        out_labels[ndim++] = 0;
+      }
+    } else {
+      CHECK(label == ' ') << "invalid subscript '" << static_cast<char>(label)
+                          << "' in einstein sum "
+                          << "subscripts string, subscripts must "
+                          << "be letters";
+    }
+  }
+
+  /* If no ellipsis was found there should be no broadcast dimensions. */
+  CHECK(!(!ellipsis && ndim_broadcast > 0)) << "output has more dimensions than subscripts "
+                                            << "given in einstein sum, but no '...' ellipsis "
+                                            << "provided to broadcast the extra dimensions.";
+
+  return ndim;
+}
+
+/*!
+ * \brief If any dimensions are combined, create a view that combines them.
+ *        The result is written to newshape and newstride.
+ *
+ * \param op the operand tensor.
+ * \param iop the index of the operand, used only in error messages.
+ * \param labels the op_labels for the operand. Like [97, 98, -2] for "aba".
+ *        Rewritten in place: the labels of the kept axes are moved to the front.
+ * \param newshape The combined shape. Must already have the combined number of dimensions;
+ *        every entry is overwritten.
+ * \param newstride The combined stride. Written for the same indices as newshape.
+ *
+ * For example:
+ *  "aba -> ab",              shape = [2,3,2] stride = [6,2,1]
+ *  op_labels = [97, 98, -2], newshape = [2,3], newstride = [7,2]
+ */
+inline void GetCombinedDimsView(const Tensor& op, int iop, char* labels, Array<PrimExpr>* newshape,
+                                Array<PrimExpr>* newstride) {
+  int idim, ndim, icombine, combineoffset;
+  int icombinemap[16];  // NOTE(review): fixed capacity, assumes ndim <= 16 - TODO confirm
+  int newdim;
+
+  Array<PrimExpr> shape = op->shape;
+  Array<PrimExpr> stride = GetStride(shape);
+  ndim = op.ndim();
+  newdim = newshape->size();
+
+  /* Initialize the dimensions and strides to zero */
+  for (idim = 0; idim < newdim; ++idim) {
+    newshape->Set(idim, 0);
+    newstride->Set(idim, 0);
+  }
+
+  /* Copy the dimensions and strides, except when collapsing.
+   * icombine is the next free index in the combined shape. */
+  icombine = 0;
+  for (idim = 0; idim < ndim; ++idim) {
+    /*
+     * The char type may be either signed or unsigned, we
+     * need it to be signed here.
+     */
+    int label = (signed char)labels[idim];
+    /* If this label says to merge axes, get the actual label */
+    if (label < 0) {
+      combineoffset = label;
+      label = labels[idim + label];
+    } else {
+      combineoffset = 0;
+      if (icombine != idim) {
+        labels[icombine] = labels[idim];
+      }
+      icombinemap[idim] = icombine;  // remember where this source axis landed
+    }
+    /* If the label is 0, it's an unlabeled broadcast dimension */
+    if (label == 0) {
+      newshape->Set(icombine, shape[idim]);
+      newstride->Set(icombine, stride[idim]);
+    } else {
+      /* Update the combined axis dimensions and strides. For a merged axis the target is
+       * the combined index of the axis the negative offset points at. */
+      int i = icombinemap[idim + combineoffset];
+      CHECK(!((combineoffset < 0) &&
+              GetConstInt((*newshape)[i] != 0 && (*newshape)[i] != shape[idim])))
+          << "dimensions in operand " << iop << " for collapsing index '" << label
+          << "' don't match (" << GetConstInt((*newshape)[i]) << " != " << shape[idim] << ")";
+      newshape->Set(i, shape[idim]);
+      newstride->Set(i, (*newstride)[i] + stride[idim]);  // strides of merged axes add up
+    }
+
+    /* If the label didn't say to combine axes, increment dest i */
+    if (combineoffset == 0) {
+      icombine++;
+    }
+  }
+}
+
+/*!
+ * \brief Map every iteration axis to the operand axis that feeds it.
+ *
+ * \param ndim the ndim of the operand tensor.
+ * \param iop the index of the operand (not used by the body).
+ * \param labels the op_labels for the operand, 0 marking an unlabeled broadcast axis.
+ * \param axes The result, one entry per iteration axis: the index of the matching operand
+ *        axis, or -1 when the operand has no such axis.
+ * \param ndim_iter the dimension of iterating. Subscripts "ab, bc -> ac" ndim_iter = 3.
+ * \param iter_labels output_labels with the iterating label. ['a', 'c', 'b'] for the case above.
+ *
+ * \return 0.
+ */
+inline static int PrepareOpAxes(int ndim, int iop, char* labels, int* axes, int ndim_iter,
+                                char* iter_labels) {
+  // Cursor over the operand axes, consumed from the last axis towards the first.
+  int next_broadcast = ndim - 1;
+
+  for (int pos = ndim_iter; pos-- > 0;) {
+    const int wanted = iter_labels[pos];
+
+    if (wanted != 0) {
+      // Labeled axis: take the first operand axis carrying the label, -1 if there is none.
+      const void* hit = memchr(labels, wanted, ndim);
+      if (hit != nullptr) {
+        axes[pos] = static_cast<const char*>(hit) - labels;
+      } else {
+        axes[pos] = -1;
+      }
+      continue;
+    }
+
+    // Unlabeled broadcast axis: skip labeled operand axes until an unlabeled one is found.
+    while (next_broadcast >= 0 && labels[next_broadcast] != 0) {
+      --next_broadcast;
+    }
+    // A negative cursor means the operand has no unlabeled axis left for this position.
+    axes[pos] = (next_broadcast < 0) ? -1 : next_broadcast--;
+  }
+  return 0;
+}
+
+/*!
+ * \brief Count the non-overlapping occurrences of a pattern in a string.
+ * \param str the string to search in.
+ * \param sub the pattern string. An empty pattern is counted as zero occurrences.
+ *
+ * \return the number of occurrences, scanning left to right and resuming after each match.
+ */
+inline int CountSubstring(const std::string& str, const std::string& sub) {
+  // An empty pattern is found at every position without advancing the search position,
+  // so the scan below would never terminate.
+  if (sub.empty()) {
+    return 0;
+  }
+  int count = 0;
+  std::string::size_type pos = 0;
+  while ((pos = str.find(sub, pos)) != std::string::npos) {
+    ++count;
+    // Resume after the match so that overlapping occurrences are not counted twice.
+    pos += sub.length();
+  }
+  return count;
+}
+
+/*!
+ * \brief Collect the characters of a string into a bitset indexed by character code.
+ * \param str input string.
+ *
+ * \return bitset in which the bit of every character code occurring in the string is set.
+ */
+inline std::bitset<128> Str2Set(const std::string& str) {
+  std::bitset<128> members;
+  for (std::string::size_type k = 0; k < str.size(); ++k) {
+    members.set(static_cast<int>(str[k]));
+  }
+  return members;
+}
+
+/*!
+ * \brief Split a string at every occurrence of a separator.
+ * \param str input string.
+ * \param sub the separator string. With an empty separator the input is returned unsplit.
+ *
+ * \return vector containing the pieces between separators, in order. Empty pieces are
+ *         kept, so the result always has at least one element.
+ */
+inline std::vector<std::string> Split(const std::string& str, const std::string& sub) {
+  std::vector<std::string> ret;
+  // An empty separator is found at every position without consuming any input, which
+  // would make the loop below append empty pieces forever.
+  if (sub.empty()) {
+    ret.push_back(str);
+    return ret;
+  }
+  std::string::size_type pos = 0;
+  std::string::size_type start = 0;
+  while ((pos = str.find(sub, start)) != std::string::npos) {
+    ret.push_back(str.substr(start, pos - start));
+    start = pos + sub.length();
+  }
+  // Whatever follows the last separator (possibly nothing) is the final piece.
+  ret.push_back(str.substr(start));
+  return ret;
+}
+
+/*!
+ * \brief Normalize einsum subscripts into explicit input and output label strings.
+ *
+ * Spaces are removed, every "..." is expanded into concrete labels that are not used
+ * elsewhere in the subscripts, and a missing output is inferred as the sorted labels that
+ * occur exactly once in the inputs. Invalid subscripts are reported through CHECK.
+ *
+ * \param subscripts input subscripts.
+ * \param operands the shape of each operand; only the number of dimensions is used.
+ *
+ * \return a tuple (inputs, output): the comma separated input subscripts and the output
+ *         subscripts.
+ *         "ab, bc -> ac" => ("ab,bc", "ac")
+ *         "ab, bc"       => ("ab,bc", "ac")
+ */
+inline std::tuple<std::string, std::string> ParseEinsumInput(
+    std::string subscripts, const std::vector<Array<PrimExpr>>& operands) {
+  const std::string einsum_symbols = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
+  std::bitset<128> einsum_symbols_set;
+  for (const char& c : einsum_symbols) {
+    einsum_symbols_set.set(c);
+  }
+
+  // Range-checked membership test. std::bitset::test throws std::out_of_range for an
+  // index of 128 or more, which is what a non-ASCII byte turns into, so guard the index
+  // and let the CHECKs below report the offending character instead.
+  auto is_einsum_symbol = [&einsum_symbols_set](char c) {
+    const unsigned char code = static_cast<unsigned char>(c);
+    return code < 128 && einsum_symbols_set.test(code);
+  };
+
+  // Returns, in sorted order, the labels that occur exactly once in `terms` (commas are
+  // ignored) and are not members of `excluded`.
+  auto labels_used_once = [&is_einsum_symbol](std::string terms,
+                                              const std::bitset<128>& excluded) {
+    std::string once;
+    std::sort(terms.begin(), terms.end());
+    const size_t len = terms.length();
+    for (size_t i = 0; i < len; ++i) {
+      const char c = terms[i];
+      if (c == ',') {
+        continue;
+      }
+      CHECK(is_einsum_symbol(c)) << "Character " << c << " is not a valid symbol.";
+      // After sorting, a label is unique exactly when it differs from both neighbours.
+      const bool differs_from_prev = (i == 0 || terms[i - 1] != c);
+      const bool differs_from_next = (i + 1 == len || terms[i + 1] != c);
+      if (differs_from_prev && differs_from_next &&
+          !excluded.test(static_cast<unsigned char>(c))) {
+        once.append(1, c);
+      }
+    }
+    return once;
+  };
+
+  CHECK_NE(operands.size(), 0U) << "No input operands";
+
+  auto end_pos = std::remove(subscripts.begin(), subscripts.end(), ' ');
+  subscripts.erase(end_pos, subscripts.end());
+
+  // Ensure all characters are valid
+  for (const char& c : subscripts) {
+    if (c == '.' || c == ',' || c == '-' || c == '>') {
+      continue;
+    }
+    CHECK(is_einsum_symbol(c)) << "Character " << c << " is not a valid symbol.";
+  }
+
+  // Check for proper "->"
+  if (subscripts.find('-') != std::string::npos || subscripts.find('>') != std::string::npos) {
+    bool invalid = (std::count(subscripts.begin(), subscripts.end(), '-') > 1 ||
+                    std::count(subscripts.begin(), subscripts.end(), '>') > 1);
+    CHECK(!invalid && CountSubstring(subscripts, "->") == 1)
+        << "Subscripts can only contain one '->'.";
+  }
+
+  // Parse ellipses
+  if (subscripts.find('.') != std::string::npos) {
+    std::string used = subscripts;
+    used.erase(
+        std::remove_if(used.begin(), used.end(),
+                       [](const char& c) { return c == '.' || c == ',' || c == '-' || c == '>'; }),
+        used.end());
+
+    // Labels not used anywhere in the subscripts are free to stand in for "...".
+    std::bitset<128> used_set = Str2Set(used);
+    std::string ellipse_inds = "";
+    for (const char& c : einsum_symbols) {
+      if (!used_set.test(static_cast<int>(c))) {
+        ellipse_inds.append(1, c);
+      }
+    }
+    int longest = 0;
+    std::string output_sub;
+    std::vector<std::string> split_subscripts;
+    const bool out_sub = subscripts.find("->") != std::string::npos;
+
+    if (out_sub) {
+      std::vector<std::string> tmp = Split(subscripts, "->");
+      output_sub = tmp[1];
+      split_subscripts = Split(tmp[0], ",");
+    } else {
+      split_subscripts = Split(subscripts, ",");
+    }
+
+    size_t size_split_subscripts = split_subscripts.size();
+    // operands[i] is indexed per term below, so the counts have to agree before the loop
+    // rather than only in the final check.
+    CHECK_EQ(size_split_subscripts, operands.size())
+        << "Number of einsum subscripts must be equal to the "
+        << "number of operands.";
+
+    subscripts = "";
+    for (size_t i = 0; i < size_split_subscripts; ++i) {
+      std::string& sub = split_subscripts[i];
+      if (sub.find('.') != std::string::npos) {
+        CHECK_EQ(std::count(sub.begin(), sub.end(), '.'), 3) << "Invalid Ellipses";
+        CHECK_EQ(CountSubstring(sub, "..."), 1) << "Invalid Ellipses";
+
+        // Number of dimensions the ellipsis stands for: the operand rank minus the
+        // explicitly labeled dimensions. A rank-0 operand always gets an empty ellipsis.
+        int ellipse_count = 0;
+        if (operands[i].size() != 0) {
+          ellipse_count =
+              static_cast<int>(operands[i].size()) - (static_cast<int>(sub.length()) - 3);
+        }
+
+        if (ellipse_count > longest) {
+          longest = ellipse_count;
+        }
+
+        CHECK_GE(ellipse_count, 0) << "Ellipses lengths do not match.";
+        const std::string::size_type ellipsis_pos = sub.find("...");
+        if (ellipse_count == 0) {
+          sub.erase(ellipsis_pos, 3);
+        } else {
+          // Take the labels from the end of the free list so that a shorter ellipsis
+          // lines up with the trailing dimensions of a longer one.
+          std::string rep_inds = ellipse_inds.substr(ellipse_inds.length() - ellipse_count);
+          sub.replace(ellipsis_pos, 3, rep_inds);
+        }
+      }
+      subscripts += sub;
+      if (i + 1 < size_split_subscripts) {
+        subscripts += ",";
+      }
+    }
+    std::string out_ellipse;
+    if (longest == 0) {
+      out_ellipse = "";
+    } else {
+      out_ellipse = ellipse_inds.substr(ellipse_inds.length() - longest);
+    }
+
+    if (out_sub) {
+      // The output may legitimately omit the ellipsis (e.g. "i...->i"). Replacing at
+      // npos would throw std::out_of_range, so only substitute when it is present.
+      const std::string::size_type out_pos = output_sub.find("...");
+      if (out_pos != std::string::npos) {
+        output_sub.replace(out_pos, 3, out_ellipse);
+      }
+      subscripts += "->" + output_sub;
+    } else {
+      // Special care for outputless ellipses: the ellipsis labels come first, followed
+      // by the remaining labels that are used exactly once.
+      const std::string output_subscript = labels_used_once(subscripts, Str2Set(out_ellipse));
+      subscripts += "->" + out_ellipse + output_subscript;
+    }
+  }
+
+  // Build output string if does not exist
+  std::tuple<std::string, std::string> ret;
+  if (subscripts.find("->") != std::string::npos) {
+    std::vector<std::string> tmp = Split(subscripts, "->");
+    ret = std::make_tuple(tmp[0], tmp[1]);
+  } else {
+    ret = std::make_tuple(subscripts, labels_used_once(subscripts, std::bitset<128>()));
+  }
+
+  // Make sure output subscripts are in the input
+  std::bitset<128> input_subscripts_set = Str2Set(std::get<0>(ret));
+  for (const char& c : std::get<1>(ret)) {
+    CHECK(input_subscripts_set.test(c))
+        << "Output character " << c << " did not appear in the input";
+  }
+
+  // Make sure number operands is equivalent to the number of terms
+  const std::string& input_terms = std::get<0>(ret);
+  CHECK_EQ(static_cast<size_t>(std::count(input_terms.begin(), input_terms.end(), ',') + 1),
+           operands.size())
+      << "Number of einsum subscripts must be equal to the "
+      << "number of operands.";
+
+  return ret;
+}
+
+/*!
+ * \brief Compute the shape of the einsum output.
+ * \param subscripts input subscripts.
+ * \param operands the shape of each operand. Every extent is read with GetConstInt.
+ *
+ * \return the shape of the output, one extent per output label.
+ */
+inline Array<PrimExpr> NumpyEinsumShape(const std::string subscripts,
+                                        const std::vector<Array<PrimExpr>>& operands) {
+  // Normalize the subscripts into explicit "inputs" and "output" label strings.
+  const std::tuple<std::string, std::string> parsed_subscripts =
+      ParseEinsumInput(subscripts, operands);
+  const std::vector<std::string> input_list = Split(std::get<0>(parsed_subscripts), ",");
+
+  // Extent recorded for each label, indexed by character code; -1 means "not seen yet".
+  int dimension_dict[128];
+  std::fill(std::begin(dimension_dict), std::end(dimension_dict), -1);
+
+  for (size_t i = 0; i < input_list.size(); ++i) {
+    const std::string& term = input_list[i];
+    const Array<PrimExpr>& sh = operands[i];
+    CHECK_EQ(sh.size(), term.length())
+        << "Einstein sum subscript " << input_list[i] << " does not contain the "
+        << "correct number of indices for operand " << i << ".";
+
+    size_t j = 0;
+    for (const char& c : term) {
+      int64_t dim = GetConstInt(sh[j++]);
+      int& recorded = dimension_dict[static_cast<int>(c)];
+
+      if (recorded == -1) {
+        // First time this label is seen.
+        recorded = dim;
+        continue;
+      }
+      // For broadcasting cases we always want the largest dim size
+      if (recorded == 1) {
+        recorded = dim;
+      }
+      CHECK(dim == 1 || dim == dimension_dict[static_cast<int>(c)])
+          << "Size of label '" << c << "' for operand  " << i << " (" << recorded
+          << ") does not match previous terms (" << dim << ").";
+    }
+  }
+
+  // Look up the recorded extent of every output label.
+  const std::string& output_str = std::get<1>(parsed_subscripts);
+  Array<PrimExpr> oshape(output_str.size(), -1);
+  size_t pos = 0;
+  for (const char& out_label : output_str) {
+    oshape.Set(pos++, dimension_dict[static_cast<int>(out_label)]);
+  }
+  // The check for assigning into a provided output shape is not performed here.
+  return oshape;
+}
+
+/*!
+ * \brief Evaluates the Einstein summation convention on the operands.
+ *
+ * \param subscripts_str Specifies the subscripts for summation as comma separated list of
+ * subscript labels.
+ * \param inputs Arrays for the operation.
+ * \param name The name of the operation.
+ * \param tag The tag to mark the operation.
+ *
+ * \return The calculation based on the Einstein summation convention.
+ */
+inline Tensor einsum(const std::string& subscripts_str, const Array<Tensor> inputs,
+                     std::string name = "T_einsum", std::string tag = kMatMul) {

Review comment:
       Created a new tag called kEinsum.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org