hanke580 commented on a change in pull request #6370:
URL: https://github.com/apache/incubator-tvm/pull/6370#discussion_r482177355



##########
File path: include/tvm/topi/transform.h
##########
@@ -1281,6 +1285,832 @@ inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, Array<PrimExp
   return compute(output_shape, func, name, tag);
 }
 
+/*!
+ * \brief Compute row-major strides for a shape, zeroing strides of small axes.
+ *
+ * Walks the shape from the last axis to the first, accumulating the product of
+ * the extents already visited. An axis whose extent is not greater than 1 gets
+ * stride 0, so it contributes nothing to a linear offset built from these
+ * strides.
+ *
+ * \param shape The tensor shape. Every extent must be a compile-time constant,
+ *              because GetConstInt is applied to each one.
+ * \return One stride per axis, in the same order as \p shape. Strides are
+ *         emitted as 32-bit integer constants.
+ */
+inline Array<PrimExpr> get_stride(const Array<PrimExpr>& shape) {
+  const int ndim = static_cast<int>(shape.size());
+  // Accumulate in 64 bits: the previous `int` accumulator could overflow
+  // (undefined behaviour), even on the final multiplication whose result is
+  // never used.
+  int64_t prod = 1;
+  Array<PrimExpr> stride = Array<PrimExpr>(shape.size(), -1);
+  for (int i = ndim - 1; i >= 0; --i) {
+    // Strides are emitted as 32-bit constants; refuse to silently truncate.
+    CHECK_EQ(static_cast<int64_t>(static_cast<int>(prod)), prod)
+        << "get_stride: stride " << prod << " does not fit in a 32-bit integer";
+    stride.Set(i, if_then_else(shape[i] > 1, static_cast<int>(prod), 0));
+    prod *= GetConstInt(shape[i]);
+  }
+  return stride;
+}
+
+/*!
+ * \brief Extend a shape to a target rank by appending unit extents.
+ *
+ * \param shape The shape to extend; its entries keep their positions.
+ * \param odim  The requested output rank; must be >= shape.size().
+ * \return An array of length \p odim whose leading entries are copied from
+ *         \p shape and whose remaining entries are the constant 1.
+ */
+inline Array<PrimExpr> pad(const Array<PrimExpr> shape, int odim) {
+  const int ndim = shape.size();
+  CHECK_GE(odim, ndim);
+  // Begin with an all-ones array of the target rank, then overwrite the
+  // prefix with the input extents.
+  Array<PrimExpr> padded(static_cast<size_t>(odim), 1);
+  int pos = 0;
+  while (pos < ndim) {
+    padded.Set(pos, shape[pos]);
+    ++pos;
+  }
+  return padded;
+}
+
+inline int parse_operand_subscripts(const char *subscripts, int length,
+                                    int ndim, int iop, char *op_labels,
+                                    char *label_counts, int *min_label, int 
*max_label) {
+  int i;
+  int idim = 0;
+  int ellipsis = -1;
+
+  /* Process all labels for this operand */
+  for (i = 0; i < length; ++i) {
+    int label = subscripts[i];
+
+    /* A proper label for an axis. */
+    if (label > 0 && isalpha(label)) {
+      /* Check we don't exceed the operator dimensions. */
+      CHECK(idim < ndim)
+        << "einstein sum subscripts string contains "
+        << "too many subscripts for operand "
+        << iop;
+
+      op_labels[idim++] = label;
+      if (label < *min_label) {
+        *min_label = label;
+      }
+      if (label > *max_label) {
+        *max_label = label;
+      }
+      label_counts[label]++;
+    } else if (label == '.') {
+      /* The beginning of the ellipsis. */
+      /* Check it's a proper ellipsis. */
+      CHECK(!(ellipsis != -1 || i + 2 >= length
+              || subscripts[++i] != '.' || subscripts[++i] != '.'))
+        << "einstein sum subscripts string contains a "
+        << "'.' that is not part of an ellipsis ('...') "
+        << "in operand "
+        << iop;
+
+      ellipsis = idim;
+    } else {
+        CHECK(label == ' ')
+          << "invalid subscript '" << static_cast<char>(label)
+          << "' in einstein sum "
+          << "subscripts string, subscripts must "
+          << "be letters";
+    }
+  }
+
+  /* No ellipsis found, labels must match dimensions exactly. */
+  if (ellipsis == -1) {
+    CHECK(idim == ndim)
+      << "operand has more dimensions than subscripts "
+      << "given in einstein sum, but no '...' ellipsis "
+      << "provided to broadcast the extra dimensions.";
+  } else if (idim < ndim) {
+    /* Ellipsis found, may have to add broadcast dimensions. */
+    /* Move labels after ellipsis to the end. */
+    for (i = 0; i < idim - ellipsis; ++i) {
+      op_labels[ndim - i - 1] = op_labels[idim - i - 1];
+    }
+    /* Set all broadcast dimensions to zero. */
+    for (i = 0; i < ndim - idim; ++i) {
+      op_labels[ellipsis + i] = 0;
+    }
+  }
+
+  /*
+   * Find any labels duplicated for this operand, and turn them
+   * into negative offsets to the axis to merge with.
+   *
+   * In C, the char type may be signed or unsigned, but with
+   * twos complement arithmetic the char is ok either way here, and
+   * later where it matters the char is cast to a signed char.
+   */
+  for (idim = 0; idim < ndim - 1; ++idim) {
+    int label = op_labels[idim];
+    /* If it is a proper label, find any duplicates of it. */

Review comment:
       Created a header file "einsum.h"




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to