This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 618ef9e [TOPI] Add einsum operator (#6370)
618ef9e is described below
commit 618ef9e87a833ea1e9fc69b013f8d181a530b7b8
Author: Ke Han <[email protected]>
AuthorDate: Thu Feb 4 02:13:40 2021 +0800
[TOPI] Add einsum operator (#6370)
* [TOPI] Einsum
* Fix tuple
* fix oshape
* * test
* * Fix lint
* * Remove useless define
* * Move to einsum header file
* * Fix single value situation
* * Fix CamelCase
* * Print stride
* * Fix single input bug
* * fix lint
* * Fix lint and add comments
* * create test einsum
* * Fix lint
* * Fix comments
---
include/tvm/topi/einsum.h | 943 +++++++++++++++++++++++++++
include/tvm/topi/tags.h | 1 +
python/tvm/topi/__init__.py | 1 +
python/tvm/topi/einsum.py | 44 ++
src/topi/transform.cc | 5 +
tests/python/topi/python/test_topi_einsum.py | 78 +++
tests/python/unittest/test_te_autodiff.py | 4 +
7 files changed, 1076 insertions(+)
diff --git a/include/tvm/topi/einsum.h b/include/tvm/topi/einsum.h
new file mode 100644
index 0000000..e1baada
--- /dev/null
+++ b/include/tvm/topi/einsum.h
@@ -0,0 +1,943 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file topi/einsum.h
+ * \brief Einstein summation op
+ */
+#ifndef TVM_TOPI_EINSUM_H_
+#define TVM_TOPI_EINSUM_H_
+
+#define LABELRANGE 128
+#define NPY_MAXDIMS 16
+#define NPY_MAXARGS 16
+
+#include <tvm/te/operation.h>
+#include <tvm/tir/data_layout.h>
+#include <tvm/topi/detail/constant_utils.h>
+#include <tvm/topi/detail/ravel_unravel.h>
+#include <tvm/topi/detail/tensor_utils.h>
+#include <tvm/topi/tags.h>
+
+#include <algorithm>
+#include <bitset>
+#include <iterator>
+#include <string>
+#include <tuple>
+#include <unordered_set>
+#include <vector>
+
+namespace tvm {
+namespace topi {
+
+using namespace tvm::te;
+using namespace topi::detail;
+
+/*!
+ * \brief Compute the stride of the given shape.
+ *
+ * \param shape for the operation.
+ *
+ * \return the stride of the shape.
+ */
+inline Array<PrimExpr> GetStride(const Array<PrimExpr> shape) {
+ size_t ndim = shape.size();
+ int prod = 1;
+ Array<PrimExpr> stride = Array<PrimExpr>(ndim, -1);
+ for (int i = ndim - 1; i >= 0; i--) {
+ stride.Set(i, if_then_else(shape[i] > 1, prod, 0));
+ prod = prod * GetConstInt(shape[i]);
+ }
+ return stride;
+}
+
+/*!
+ * \brief Pad the shape with 1.
+ *
+ * \param shape the input shape to be padded
+ * \param odim the padding size of the objective shape.
+ *
+ * \return the padded shape.
+ */
+inline Array<PrimExpr> Pad(const Array<PrimExpr> shape, int odim) {
+ int ndim = shape.size();
+ CHECK_GE(odim, ndim);
+ Array<PrimExpr> ret(static_cast<size_t>(odim), 1);
+ for (int idim = 0; idim < ndim; ++idim) {
+ ret.Set(idim, shape[idim]);
+ }
+ return ret;
+}
+
/*!
 * \brief Parse the subscripts for one operand into an output of 'ndim' labels.
 *
 * \param subscripts the subscripts to be parsed.
 * \param length subscripts[0: length] represents the current operand.
 * \param ndim the ndim of the current operand.
 * \param iop the index of the operand (used only in error messages).
 * \param op_labels the parsing result: one entry per dimension of the operand.
 *        A positive value is the ASCII code of the label, 0 marks a broadcast
 *        ('...') dimension, and a negative value is an offset to an earlier
 *        dimension carrying the same (duplicated) label.
 * For Example:
 *     subscripts="abbcbc",  ndim=6 -> op_labels=[97, 98, -1, 99, -3, -2].
 *     subscripts="ab...bc", ndim=6 -> op_labels=[97, 98, 0, 0, -3, 99].
 * \param label_counts counts how many times each label appears (indexed by ASCII code).
 * \param min_label running minimum label seen so far (ASCII); updated in place.
 * \param max_label running maximum label seen so far (ASCII); updated in place.
 *
 * \return 0 on success (errors abort via CHECK).
 */
inline int ParseOperandSubscripts(const char* subscripts, int length, int ndim, int iop,
                                  char* op_labels, char* label_counts, int* min_label,
                                  int* max_label) {
  int i;
  int idim = 0;
  /* Position (in op_labels) where the ellipsis starts, or -1 if none seen yet. */
  int ellipsis = -1;

  /* Process all labels for this operand */
  for (i = 0; i < length; ++i) {
    int label = subscripts[i];

    /* A proper label for an axis. */
    if (label > 0 && isalpha(label)) {
      /* Check we don't exceed the operator dimensions. */
      CHECK(idim < ndim) << "einstein sum subscripts string contains "
                         << "too many subscripts for operand " << iop;

      op_labels[idim++] = label;
      if (label < *min_label) {
        *min_label = label;
      }
      if (label > *max_label) {
        *max_label = label;
      }
      label_counts[label]++;
    } else if (label == '.') {
      /* The beginning of the ellipsis. */
      /* Check it's a proper ellipsis: only one per operand, and exactly
       * three consecutive dots (note ++i consumes the next two chars). */
      CHECK(
          !(ellipsis != -1 || i + 2 >= length || subscripts[++i] != '.' || subscripts[++i] != '.'))
          << "einstein sum subscripts string contains a "
          << "'.' that is not part of an ellipsis ('...') "
          << "in operand " << iop;

      ellipsis = idim;
    } else {
      CHECK(label == ' ') << "invalid subscript '" << static_cast<char>(label)
                          << "' in einstein sum "
                          << "subscripts string, subscripts must "
                          << "be letters";
    }
  }

  /* No ellipsis found, labels must match dimensions exactly. */
  if (ellipsis == -1) {
    CHECK(idim == ndim) << "operand has more dimensions than subscripts "
                        << "given in einstein sum, but no '...' ellipsis "
                        << "provided to broadcast the extra dimensions.";
  } else if (idim < ndim) {
    /* Ellipsis found, may have to add broadcast dimensions. */
    /* Move labels after ellipsis to the end (iterate forward: source and
     * destination ranges may overlap, and forward order copies safely
     * right-to-left here). */
    for (i = 0; i < idim - ellipsis; ++i) {
      op_labels[ndim - i - 1] = op_labels[idim - i - 1];
    }
    /* Set all broadcast dimensions to zero. */
    for (i = 0; i < ndim - idim; ++i) {
      op_labels[ellipsis + i] = 0;
    }
  }

  /*
   * Find any labels duplicated for this operand, and turn them
   * into negative offsets to the axis to merge with.
   *
   * In C, the char type may be signed or unsigned, but with
   * twos complement arithmetic the char is ok either way here, and
   * later where it matters the char is cast to a signed char.
   */
  for (idim = 0; idim < ndim - 1; ++idim) {
    int label = op_labels[idim];
    /* If it is a proper label, find any duplicates of it. */
    if (label > 0) {
      /* Search for the next matching label. */
      char* next = reinterpret_cast<char*>(memchr(op_labels + idim + 1, label, ndim - idim - 1));

      while (next != nullptr) {
        /* The offset from next to op_labels[idim] (negative). */
        *next = static_cast<char>((op_labels + idim) - next);
        /* Search for the next matching label; the remaining length is the
         * distance from just past `next` to the end of op_labels. */
        next = reinterpret_cast<char*>(memchr(next + 1, label, op_labels + ndim - 1 - next));
      }
    }
  }
  return 0;
}
+
/*!
 * \brief Parse the subscripts for the output into an output that includes
 * 'ndim_broadcast' unlabeled dimensions.
 *
 * \param subscripts the subscripts to be parsed.
 * \param length subscripts[0: length] represents the output operand.
 * \param ndim_broadcast the number of broadcast ('...') dimensions to expand.
 * \param label_counts how many times each label appeared in the inputs
 *        (indexed by ASCII code); used to reject output-only labels.
 * \param out_labels similar to op_labels in ParseOperandSubscripts: for each
 *        output dimension, the ASCII code of its label, or 0 for a
 *        broadcast dimension.
 *
 * \return the total number of output dimensions (errors abort via CHECK).
 */
inline int ParseOutputSubscripts(const char* subscripts, int length, int ndim_broadcast,
                                 const char* label_counts, char* out_labels) {
  int i, bdim;
  int ndim = 0;
  /* 1 once an ellipsis has been consumed; at most one is allowed. */
  int ellipsis = 0;

  /* Process all the output labels. */
  for (i = 0; i < length; ++i) {
    int label = subscripts[i];

    /* A proper label for an axis. */
    if (label > 0 && isalpha(label)) {
      /* Check that it doesn't occur again (output labels must be unique). */
      CHECK(memchr(subscripts + i + 1, label, length - i - 1) == nullptr)
          << "einstein sum subscripts string includes "
          << "output subscript '" << static_cast<char>(label) << "' multiple times";

      /* Check that it was used in the inputs. */
      CHECK(label_counts[label] != 0)
          << "einstein sum subscripts string included "
          << "output subscript '" << static_cast<char>(label) << "' which never appeared "
          << "in an input";

      /* Check that there is room in out_labels for this label. */
      CHECK(ndim < NPY_MAXDIMS) << "einstein sum subscripts string contains "
                                << "too many subscripts in the output";

      out_labels[ndim++] = label;
    } else if (label == '.') {
      /* The beginning of the ellipsis. */
      /* Check it is a proper ellipsis (exactly three dots; ++i consumes them). */
      CHECK(!(ellipsis || i + 2 >= length || subscripts[++i] != '.' || subscripts[++i] != '.'))
          << "einstein sum subscripts string "
          << "contains a '.' that is not part of "
          << "an ellipsis ('...') in the output";

      /* Check there is room in out_labels for broadcast dims. */
      CHECK(ndim + ndim_broadcast <= NPY_MAXDIMS) << "einstein sum subscripts string contains "
                                                  << "too many subscripts in the output";

      ellipsis = 1;
      /* Expand the ellipsis into ndim_broadcast unlabeled (0) dimensions. */
      for (bdim = 0; bdim < ndim_broadcast; ++bdim) {
        out_labels[ndim++] = 0;
      }
    } else {
      CHECK(label == ' ') << "invalid subscript '" << static_cast<char>(label)
                          << "' in einstein sum "
                          << "subscripts string, subscripts must "
                          << "be letters";
    }
  }

  /* If no ellipsis was found there should be no broadcast dimensions. */
  CHECK(!(!ellipsis && ndim_broadcast > 0)) << "output has more dimensions than subscripts "
                                            << "given in einstein sum, but no '...' ellipsis "
                                            << "provided to broadcast the extra dimensions.";

  return ndim;
}
+
+/*!
+ * \brief If any dimensions are combined, create a view that combines them.
+ * Shows in newshape and newstride.
+ *
+ * \param op the operand tensor.
+ * \param iop the index of the operand.
+ * \param labels the op_labels fot the operand. Like [97, 98, -2] for "aba".
+ * \param newshape The combined shape.
+ * \param newstride The combined stride.
+ *
+ * For example:
+ * "aba -> ab", shape = [2,3,2] stride = [6,2,1]
+ * op_labels = [97, 98, -2], newshape = [2,3], newstride = [7,2]
+ */
+inline void GetCombinedDimsView(const Tensor& op, int iop, char* labels,
Array<PrimExpr>* newshape,
+ Array<PrimExpr>* newstride) {
+ int idim, ndim, icombine, combineoffset;
+ int icombinemap[NPY_MAXDIMS];
+ int newdim;
+
+ Array<PrimExpr> shape = op->shape;
+ Array<PrimExpr> stride = GetStride(shape);
+ ndim = op.ndim();
+ newdim = newshape->size();
+
+ /* Initialize the dimensions and strides to zero */
+ for (idim = 0; idim < newdim; ++idim) {
+ newshape->Set(idim, 0);
+ newstride->Set(idim, 0);
+ }
+
+ /* Copy the dimensions and strides, except when collapsing */
+ icombine = 0;
+ for (idim = 0; idim < ndim; ++idim) {
+ /*
+ * The char type may be either signed or unsigned, we
+ * need it to be signed here.
+ */
+ int label = (signed char)labels[idim];
+ /* If this label says to merge axes, get the actual label */
+ if (label < 0) {
+ combineoffset = label;
+ label = labels[idim + label];
+ } else {
+ combineoffset = 0;
+ if (icombine != idim) {
+ labels[icombine] = labels[idim];
+ }
+ icombinemap[idim] = icombine;
+ }
+ /* If the label is 0, it's an unlabeled broadcast dimension */
+ if (label == 0) {
+ newshape->Set(icombine, shape[idim]);
+ newstride->Set(icombine, stride[idim]);
+ } else {
+ /* Update the combined axis dimensions and strides */
+ int i = icombinemap[idim + combineoffset];
+ CHECK(!((combineoffset < 0) &&
+ GetConstInt((*newshape)[i] != 0 && (*newshape)[i] !=
shape[idim])))
+ << "dimensions in operand " << iop << " for collapsing index '" <<
label
+ << "' don't match (" << GetConstInt((*newshape)[i]) << " != " <<
shape[idim] << ")";
+ newshape->Set(i, shape[idim]);
+ newstride->Set(i, (*newstride)[i] + stride[idim]);
+ }
+
+ /* If the label didn't say to combine axes, increment dest i */
+ if (combineoffset == 0) {
+ icombine++;
+ }
+ }
+}
+
/*!
 * \brief Prepare the operand axes to match each stride or shape pair.
 *
 * For every iterator dimension, find which axis of the operand (if any)
 * corresponds to it; -1 means the operand is broadcast along that dimension.
 *
 * \param ndim the ndim of the operand tensor.
 * \param iop the index of the operand (currently unused in the body).
 * \param labels the op_labels for the operand, e.g. [97, 98, -1, 99, -3, -2]
 *        for "abbcbc" (0 = broadcast dim, negative = merged duplicate).
 * \param axes output: for each iterator dimension, the matching operand axis
 *        or -1 for a "newaxis" broadcast.
 * \param ndim_iter the number of iterator dimensions. For "ab, bc -> ac"
 *        ndim_iter = 3.
 * \param iter_labels output labels followed by the reduced labels,
 *        e.g. ['a', 'c', 'b'] for the case above.
 *
 * \return 0 on success.
 */
inline static int PrepareOpAxes(int ndim, int iop, char* labels, int* axes, int ndim_iter,
                                char* iter_labels) {
  int i, label, ibroadcast;

  /* Scan operand broadcast dims from the last one backwards, mirroring the
   * reverse iteration over iterator dimensions below. */
  ibroadcast = ndim - 1;
  for (i = ndim_iter - 1; i >= 0; --i) {
    label = iter_labels[i];
    /*
     * If it's an unlabeled broadcast dimension, choose
     * the next broadcast dimension from the operand.
     */
    if (label == 0) {
      while (ibroadcast >= 0 && labels[ibroadcast] != 0) {
        --ibroadcast;
      }
      /*
       * If we used up all the operand broadcast dimensions,
       * extend it with a "newaxis"
       */
      if (ibroadcast < 0) {
        axes[i] = -1;
      } else {
        /* Otherwise map to the broadcast axis */
        axes[i] = ibroadcast;
        --ibroadcast;
      }
    } else {
      /* It's a labeled dimension, find the matching one */
      char* match = reinterpret_cast<char*>(memchr(labels, label, ndim));
      /* If the op doesn't have the label, broadcast it */
      if (match == nullptr) {
        axes[i] = -1;
      } else {
        /* Otherwise use it */
        axes[i] = match - labels;
      }
    }
  }
  return 0;
}
+
/*!
 * \brief Count non-overlapping occurrences of a substring.
 * \param str the string to search in.
 * \param sub the pattern string to count.
 *
 * \return the number of non-overlapping occurrences of sub in str.
 */
inline int CountSubstring(const std::string& str, const std::string& sub) {
  int occurrences = 0;
  // Each hit advances past the matched text, so matches never overlap.
  for (std::string::size_type at = str.find(sub); at != std::string::npos;
       at = str.find(sub, at + sub.length())) {
    ++occurrences;
  }
  return occurrences;
}
+
+/*!
+ * \brief Transfer string to.
+ * \param str input string.
+ *
+ * \return bitset.
+ */
+inline std::bitset<LABELRANGE> Str2Set(const std::string& str) {
+ std::bitset<LABELRANGE> ret;
+ for (const char& c : str) {
+ ret.set(static_cast<int>(c));
+ }
+ return ret;
+}
+
/*!
 * \brief Split a string on every occurrence of a separator substring.
 * \param str input string.
 * \param sub the separator string.
 *
 * \return the pieces of str between separators; always non-empty, and
 *         adjacent/leading/trailing separators yield empty pieces.
 */
inline std::vector<std::string> Split(const std::string& str, const std::string& sub) {
  std::vector<std::string> pieces;
  std::string::size_type begin = 0;
  for (std::string::size_type hit = str.find(sub, begin); hit != std::string::npos;
       hit = str.find(sub, begin)) {
    pieces.push_back(str.substr(begin, hit - begin));
    begin = hit + sub.length();
  }
  // The tail after the last separator (or the whole string if none matched).
  pieces.push_back(str.substr(begin));
  return pieces;
}
+
/*!
 * \brief Parse the input subscripts into input and output parts, expanding
 * any ellipses ('...') into explicit labels.
 * \param subscripts input subscripts string (spaces allowed, removed here).
 * \param operands the shapes of the operand tensors, used to size ellipses.
 *
 * \return a tuple of two strings: get<0> is the comma-separated input part,
 *         get<1> is the output part. If no "->" output was given, the output
 *         part is built from labels that appear exactly once, in sorted order.
 *         "ab, bc -> ac" => ("ab,bc", "ac")
 */
inline std::tuple<std::string, std::string> ParseEinsumInput(
    std::string subscripts, const std::vector<Array<PrimExpr>>& operands) {
  const std::string einsum_symbols = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
  std::bitset<LABELRANGE> einsum_symbols_set;
  for (const char& c : einsum_symbols) {
    einsum_symbols_set.set(c);
  }

  CHECK_NE(operands.size(), 0U) << "No input operands";

  // Strip all spaces before any further parsing.
  auto end_pos = std::remove(subscripts.begin(), subscripts.end(), ' ');
  subscripts.erase(end_pos, subscripts.end());

  // Ensure all characters are valid
  for (const char& c : subscripts) {
    if (c == '.' || c == ',' || c == '-' || c == '>') {
      continue;
    }
    CHECK(einsum_symbols_set.test(c)) << "Character " << c << " is not a valid symbol.";
  }

  // Check for proper "->"
  if (subscripts.find('-') != std::string::npos || subscripts.find('>') != std::string::npos) {
    bool invalid = (std::count(subscripts.begin(), subscripts.end(), '-') > 1 ||
                    std::count(subscripts.begin(), subscripts.end(), '>') > 1);
    CHECK(!invalid && CountSubstring(subscripts, "->") == 1)
        << "Subscripts can only contain one '->'.";
  }

  // Parse ellipses: replace every '...' with unused labels sized from the
  // corresponding operand's rank.
  if (subscripts.find('.') != std::string::npos) {
    // Collect the labels already in use so the replacement labels are fresh.
    std::string used = subscripts;
    used.erase(
        std::remove_if(used.begin(), used.end(),
                       [](const char& c) { return c == '.' || c == ',' || c == '-' || c == '>'; }),
        used.end());

    std::bitset<LABELRANGE> used_set = Str2Set(used);
    // Pool of labels not used anywhere in the subscripts; ellipses draw from
    // its tail so all operands share the same trailing labels.
    std::string ellipse_inds = "";
    for (const char& c : einsum_symbols) {
      if (!used_set.test(static_cast<int>(c))) {
        ellipse_inds.append(1, c);
      }
    }
    int longest = 0;  // widest ellipsis across all operands
    std::string input_tmp, output_sub;
    std::vector<std::string> split_subscripts;
    bool out_sub;  // true when an explicit "->" output exists

    if (subscripts.find("->") != std::string::npos) {
      std::vector<std::string> tmp = Split(subscripts, "->");
      input_tmp = tmp[0];
      output_sub = tmp[1];
      split_subscripts = Split(input_tmp, ",");
      out_sub = true;
    } else {
      split_subscripts = Split(subscripts, ",");
      out_sub = false;
    }

    // Rebuild the input part with each operand's '...' expanded.
    size_t size_split_subscripts = split_subscripts.size();
    subscripts = "";
    for (size_t i = 0; i < size_split_subscripts; ++i) {
      const std::string& sub = split_subscripts[i];
      if (sub.find('.') != std::string::npos) {
        CHECK_EQ(std::count(sub.begin(), sub.end(), '.'), 3) << "Invalid Ellipses";
        CHECK_EQ(CountSubstring(sub, "..."), 1) << "Invalid Ellipses";

        // Take into account numerical values: the ellipsis covers whatever
        // dimensions the explicit labels don't.
        int ellipse_count = 0;
        if (operands[i].size() == 0) {
          ellipse_count = 0;
        } else {
          ellipse_count = std::max(operands[i].size(), static_cast<size_t>(1));
          ellipse_count -= sub.length() - 3;
        }

        if (ellipse_count > longest) {
          longest = ellipse_count;
        }

        CHECK_GE(ellipse_count, 0) << "Ellipses lengths do not match.";
        if (ellipse_count == 0) {
          split_subscripts[i].erase(sub.find("..."), 3);
        } else {
          // Use the last ellipse_count pool labels so shorter ellipses align
          // with the tail of longer ones (right-aligned broadcasting).
          std::string rep_inds = ellipse_inds.substr(ellipse_inds.length() - ellipse_count);
          split_subscripts[i].replace(sub.find("..."), 3, rep_inds);
        }
      }
      subscripts += split_subscripts[i];
      if (i + 1 < size_split_subscripts) {
        subscripts += ",";
      }
    }
    // Labels standing in for the output's ellipsis (widest one wins).
    std::string out_ellipse;
    if (longest == 0) {
      out_ellipse = "";
    } else {
      out_ellipse = ellipse_inds.substr(ellipse_inds.length() - longest);
    }

    if (out_sub) {
      output_sub.replace(output_sub.find("..."), 3, out_ellipse);
      subscripts += "->" + output_sub;
    } else {
      // Special care for outputless ellipses: synthesize the output from the
      // broadcast labels plus every label that appears exactly once.
      std::bitset<LABELRANGE> out_ellipse_set = Str2Set(out_ellipse);
      std::string tmp_subscripts = subscripts, output_subscript = "";
      size_t len_tmp_subscripts = tmp_subscripts.length();
      // Sorting groups duplicates so "appears exactly once" is a local test.
      std::sort(tmp_subscripts.begin(), tmp_subscripts.end());
      for (size_t i = 0; i < len_tmp_subscripts; ++i) {
        const char& c = tmp_subscripts[i];
        if (c == ',') {
          continue;
        }
        CHECK(einsum_symbols_set.test(c)) << "Character " << c << " is not a valid symbol.";
        if ((i == 0 || tmp_subscripts[i - 1] != c) &&
            (i == len_tmp_subscripts - 1 || tmp_subscripts[i + 1] != c) &&
            !out_ellipse_set.test(c)) {
          output_subscript.append(1, c);
        }
      }
      subscripts += "->" + out_ellipse + output_subscript;
    }
  }

  // Build output string if it does not exist
  std::tuple<std::string, std::string> ret;
  if (subscripts.find("->") != std::string::npos) {
    std::vector<std::string> tmp(2);
    tmp = Split(subscripts, "->");
    ret = std::make_tuple(tmp[0], tmp[1]);
  } else {
    std::string first = subscripts;
    std::string second = "";
    // Build output subscripts: sorted labels appearing exactly once.
    std::string tmp_subscripts = subscripts;
    size_t len_tmp_subscripts = tmp_subscripts.length();
    std::sort(tmp_subscripts.begin(), tmp_subscripts.end());
    for (size_t i = 0; i < len_tmp_subscripts; ++i) {
      const char& c = tmp_subscripts[i];
      if (c == ',') {
        continue;
      }
      CHECK(einsum_symbols_set.test(c)) << "Character " << c << " is not a valid symbol.";
      if ((i == 0 || tmp_subscripts[i - 1] != c) &&
          (i == len_tmp_subscripts - 1 || tmp_subscripts[i + 1] != c)) {
        second.append(1, c);
      }
    }
    ret = std::make_tuple(first, second);
  }

  // Make sure output subscripts are in the input
  std::bitset<LABELRANGE> input_subscripts_set = Str2Set(std::get<0>(ret));
  for (const char& c : std::get<1>(ret)) {
    CHECK(input_subscripts_set.test(c))
        << "Output character " << c << " did not appear in the input";
  }

  // Make sure number of operands is equivalent to the number of terms
  CHECK_EQ(std::count(std::get<0>(ret).begin(), std::get<0>(ret).end(), ',') + 1, operands.size())
      << "Number of einsum subscripts must be equal to the "
      << "number of operands.";

  return ret;
}
+
+/*!
+ * \brief Compute the shape of the output.
+ * \param subscripts input subscripts.
+ * \param operands operand tensors.
+ *
+ * \return the shape of the output.
+ */
+inline Array<PrimExpr> NumpyEinsumShape(const std::string subscripts,
+ const std::vector<Array<PrimExpr>>&
operands) {
+ // Parsing
+ std::tuple<std::string, std::string> parsed_subscripts =
ParseEinsumInput(subscripts, operands);
+
+ // Build a few useful list and sets
+ std::vector<std::string> input_list = Split(std::get<0>(parsed_subscripts),
",");
+ size_t isize = input_list.size();
+
+ // Get length of each unique dimension and ensure all dimensions are correct
+ int dimension_dict[LABELRANGE];
+ memset(dimension_dict, -1, sizeof(dimension_dict));
+ for (size_t i = 0; i < isize; ++i) {
+ const std::string& term = input_list[i];
+ const Array<PrimExpr>& sh = operands[i];
+ CHECK_EQ(sh.size(), term.length())
+ << "Einstein sum subscript " << input_list[i] << " does not contain
the "
+ << "correct number of indices for operand " << i << ".";
+ size_t len_term = term.length();
+ for (size_t j = 0; j < len_term; ++j) {
+ int64_t dim = GetConstInt(sh[j]);
+ const char& c = term[j];
+
+ if (dimension_dict[static_cast<int>(c)] != -1) {
+ // For broadcasting cases we always want the largest dim size
+ if (dimension_dict[static_cast<int>(c)] == 1) {
+ dimension_dict[static_cast<int>(c)] = dim;
+ }
+ CHECK(dim == 1 || dim == dimension_dict[static_cast<int>(c)])
+ << "Size of label '" << c << "' for operand " << i << " ("
+ << dimension_dict[static_cast<int>(c)] << ") does not match
previous terms (" << dim
+ << ").";
+ } else {
+ dimension_dict[static_cast<int>(c)] = dim;
+ }
+ }
+ }
+
+ // Get oshape
+ const std::string& output_str = std::get<1>(parsed_subscripts);
+ size_t odim = output_str.size();
+ Array<PrimExpr> oshape(odim, -1);
+ for (size_t i = 0; i < odim; ++i) {
+ oshape.Set(i, dimension_dict[static_cast<int>(output_str[i])]);
+ }
+ // Neglecting oshape assign check temporally
+ return oshape;
+}
+
/*!
 * \brief Evaluates the Einstein summation convention on the operands.
 *
 * The driver mirrors NumPy's einsum pipeline: parse the subscripts, collapse
 * duplicated labels into diagonal views, build an iteration space of
 * output + reduced dimensions, and emit a compute op that sums products of
 * the operands over the reduced dimensions. All extents must be constant.
 *
 * \param subscripts_str Specifies the subscripts for summation as comma
 *        separated list of subscript labels.
 * \param inputs Arrays for the operation.
 * \param name The name of the operation.
 * \param tag The tag to mark the operation.
 *
 * \return The calculation based on the Einstein summation convention.
 */
inline Tensor einsum(const std::string& subscripts_str, const Array<Tensor> inputs,
                     std::string name = "T_einsum", std::string tag = kEinsum) {
  // NOTE(review): `back` is always false here, so `iop + back` == `iop`
  // everywhere below — presumably a leftover from a gradient/backward mode.
  bool back = false;
  const char* subscripts = subscripts_str.data();
  const char* head = subscripts;
  const int nop = inputs.size();

  /* Step 1: Parse the subscripts string into label_counts and op_labels */
  int iop, idim, min_label = LABELRANGE - 1, max_label = 0;
  char label_counts[LABELRANGE], op_labels[NPY_MAXARGS][NPY_MAXDIMS];
  memset(label_counts, 0, sizeof(label_counts));
  for (iop = 0; iop < nop; ++iop) {
    /* Each operand's labels run up to the next ',' or the '->' marker. */
    int length = static_cast<int>(strcspn(subscripts, ",-"));

    CHECK(!(iop == nop - 1 && subscripts[length] == ','))
        << "more operands provided to einstein sum function "
        << "than specified in the subscripts string";
    CHECK(!(iop < nop - 1 && subscripts[length] != ','))
        << "fewer operands provided to einstein sum function "
        << "than specified in the subscripts string";
    CHECK_EQ(ParseOperandSubscripts(subscripts, length, inputs[iop + back].ndim(), iop,
                                    op_labels[iop], label_counts, &min_label, &max_label),
             0);

    /* Move subscripts to the start of the labels for the next op */
    subscripts += length;

    if (iop < nop - 1) {
      CHECK_LT(subscripts - head, subscripts_str.length()) << "subscripts out of range";
      subscripts++;  // skip the ','
    }
  }
  /*
   * Find the number of broadcast dimensions, which is the maximum
   * number of labels == 0 in an op_labels array.
   */
  int ndim_broadcast = 0;
  for (iop = 0; iop < nop; ++iop) {
    int count_zeros = 0;
    int ndim;
    char* labels = op_labels[iop];

    ndim = inputs[iop + back].ndim();
    for (idim = 0; idim < ndim; ++idim) {
      if (labels[idim] == 0) {
        ++count_zeros;
      }
    }

    if (count_zeros > ndim_broadcast) {
      ndim_broadcast = count_zeros;
    }
  }

  /*
   * If there is no output signature, fill output_labels and ndim_output
   * using each label that appeared once, in alphabetical order.
   */
  int label, ndim_output;
  char output_labels[NPY_MAXDIMS];
  if (subscripts[0] == '\0') {
    /* If no output was specified, always broadcast left, as usual. */
    for (ndim_output = 0; ndim_output < ndim_broadcast; ++ndim_output) {
      output_labels[ndim_output] = 0;
    }
    for (label = min_label; label <= max_label; ++label) {
      if (label_counts[label] == 1) {
        CHECK(ndim_output < NPY_MAXDIMS) << "einstein sum subscript string has too many "
                                         << "distinct labels";
        output_labels[ndim_output++] = label;
      }
    }
  } else {
    CHECK(subscripts[0] == '-' && subscripts[1] == '>') << "einstein sum subscript string does not "
                                                        << "contain proper '->' output specified";
    subscripts += 2;

    /* Parse the output subscript string. */
    ndim_output = ParseOutputSubscripts(subscripts, strlen(subscripts), ndim_broadcast,
                                        label_counts, output_labels);
    CHECK_GE(ndim_output, 0);
  }

  /*
   * Step 2:
   * Process all the input ops, combining dimensions into their
   * diagonal where specified.
   */
  std::vector<Array<PrimExpr>> opshape(nop), opstride_true(nop);
  for (iop = 0; iop < nop; ++iop) {
    char* labels = op_labels[iop];
    int combine, ndim;

    ndim = inputs[iop + back].ndim();

    /*
     * Check whether any dimensions need to be combined
     *
     * The char type may be either signed or unsigned, we
     * need it to be signed here.
     */
    combine = 0;
    for (idim = 0; idim < ndim; ++idim) {
      if ((signed char)labels[idim] < 0) {
        combine++;
      }
    }
    /* If any dimensions are combined, create a view which combines them */
    if (combine) {
      Array<PrimExpr> tshape(static_cast<size_t>(ndim - combine), -1);
      Array<PrimExpr> tstride(static_cast<size_t>(ndim - combine), -1);
      GetCombinedDimsView(inputs[iop + back], iop, labels, &tshape, &tstride);
      opshape[iop] = tshape;
      opstride_true[iop] = tstride;
    } else {
      /* No combining needed */
      opshape[iop] = inputs[iop + back]->shape;
      opstride_true[iop] = GetStride(opshape[iop]);
    }
  }
  /*
   * Step 3:
   * Set up the labels for the iterator (output + combined labels).
   * Can just share the output_labels memory, because iter_labels
   * is output_labels with some more labels appended.
   */
  char* iter_labels = output_labels;
  int ndim_iter = ndim_output;
  for (label = min_label; label <= max_label; ++label) {
    /* Any input label not already in the output becomes a reduction dim. */
    if (label_counts[label] > 0 && memchr(output_labels, label, ndim_output) == nullptr) {
      CHECK(ndim_iter < NPY_MAXDIMS) << "too many subscripts in einsum";
      iter_labels[ndim_iter++] = label;
    }
  }
  /* Step 4: Set up the op_axes for the iterator */
  Array<PrimExpr> itershape(static_cast<size_t>(ndim_iter), -1);
  std::vector<Array<PrimExpr>> iterstride(nop + 1,
                                          Array<PrimExpr>(static_cast<size_t>(ndim_iter), 0));

  // output_shape
  std::vector<Array<PrimExpr>> operands;
  for (size_t i = 0; i < inputs.size(); i++) {
    operands.push_back(inputs[i]->shape);
  }
  Array<PrimExpr> oshape = NumpyEinsumShape(subscripts_str, operands);
  Array<PrimExpr> ostride_true = GetStride(oshape);
  Array<PrimExpr> reduceshape;
  std::vector<Array<PrimExpr>> remainshape(nop);
  int op_axes_arrays[NPY_MAXARGS][NPY_MAXDIMS];
  int* op_axes[NPY_MAXARGS];
  for (iop = 0; iop < nop; ++iop) {
    op_axes[iop] = op_axes_arrays[iop];
    CHECK_GE(PrepareOpAxes(opshape[iop].size(), iop, op_labels[iop], op_axes[iop], ndim_iter,
                           iter_labels),
             0);
    for (idim = 0; idim < ndim_iter; idim++) {
      if (op_axes[iop][idim] != -1) {
        iterstride[iop].Set(idim, opstride_true[iop][op_axes[iop][idim]]);
        /* -1 means unset; a 1 may be widened by a later non-unit operand. */
        if (GetConstInt(itershape[idim]) != -1) {
          if (GetConstInt(itershape[idim]) == 1) {
            itershape.Set(idim, opshape[iop][op_axes[iop][idim]]);
          }
        } else {
          itershape.Set(idim, opshape[iop][op_axes[iop][idim]]);
        }
      }
    }
  }
  for (idim = 0; idim < ndim_output; ++idim) {
    iterstride[nop].Set(idim, ostride_true[idim]);
  }
  /* The trailing (non-output) iterator dims are the reduction space. */
  reduceshape = Array<PrimExpr>(static_cast<size_t>(ndim_iter - ndim_output), 0);
  for (idim = ndim_output; idim < ndim_iter; ++idim) {
    reduceshape.Set(idim - ndim_output, itershape[idim]);
  }
  for (iop = 0; iop < nop; iop++) {
    Array<Integer> rsh;
    for (idim = 0; idim < ndim_iter; idim++) {
      if (op_axes_arrays[iop][idim] == -1) {
        rsh.push_back(GetConstInt(itershape[idim]));
      } else {
        // NOTE(review): GetConstInt is applied to a symbolic comparison
        // (itershape[idim] != opshape[...]) rather than to each side —
        // verify this is accepted by GetConstInt; it looks like the
        // parentheses were intended around each operand.
        if (GetConstInt(itershape[idim] != opshape[iop][op_axes_arrays[iop][idim]])) {
          rsh.push_back(GetConstInt(itershape[idim]));
        }
      }
    }
    remainshape[iop] = Array<PrimExpr>(rsh.begin(), rsh.end());
  }
  // exclude the 0-dim case: pad everything to rank >= 1
  if (ndim_iter == 0) {
    ndim_iter = 1;
  }
  itershape = Pad(itershape, ndim_iter);
  for (iop = 0; iop <= nop; ++iop) {
    iterstride[iop] = Pad(iterstride[iop], ndim_iter);
  }
  // oshape = Pad(oshape, ndim_iter);
  reduceshape = Pad(reduceshape, ndim_iter);
  for (iop = 0; iop < nop; ++iop) {
    opshape[iop] = Pad(opshape[iop], ndim_iter);
    remainshape[iop] = Pad(remainshape[iop], ndim_iter);
  }
  // ostride and rstride: per-operand strides split into the output-index part
  // and the reduction-index part of the iteration space.
  Array<Array<PrimExpr>> ostride;
  Array<Array<PrimExpr>> rstride;

  for (iop = 0; iop < nop; ++iop) {
    Array<PrimExpr> otmp(static_cast<size_t>(ndim_iter), 0);
    Array<PrimExpr> rtmp(static_cast<size_t>(ndim_iter), 0);
    for (idim = 0; idim < ndim_iter; ++idim) {
      otmp.Set(idim, idim < ndim_output ? iterstride[iop][idim] : 1);
      rtmp.Set(idim, idim < ndim_iter - ndim_output ? iterstride[iop][idim + ndim_output] : 1);
    }
    ostride.push_back(otmp);
    rstride.push_back(rtmp);
  }

  // func: maps output indices to the summed product over the reduction space.
  auto func = [inputs, oshape, ostride, reduceshape, ndim_iter, rstride,
               nop](const Array<Var>& input_indices) -> PrimExpr {
    /* An empty reduction dimension means the whole sum is empty => 0. */
    for (int rdim = 0; rdim < ndim_iter; ++rdim) {
      if (GetConstInt(reduceshape[rdim]) == 0) {
        return 0;
      }
    }
    Array<PrimExpr> ridx = UnravelIndex(0, reduceshape);

    PrimExpr sum = 0;
    bool rec_flag = false;
    /* Odometer-style walk over every point of the reduction space. */
    do {
      PrimExpr tmp = 1;
      for (int iop = 0; iop < nop; ++iop) {
        // NOTE(review): `iop != -1` is always true for iop in [0, nop) —
        // likely a remnant of a skip-one-operand (gradient) mode.
        if (iop != -1) {
          /* Flat offset into the operand from output and reduction indices. */
          PrimExpr k = 0;

          for (size_t i = 0; i < input_indices.size(); ++i) {
            k += input_indices[i] * ostride[iop][i];
          }
          for (size_t i = 0; i < ridx.size(); ++i) {
            k += ridx[i] * rstride[iop][i];
          }
          Array<PrimExpr> temp_indices = UnravelIndex(k, inputs[iop]->shape);
          tmp = tmp * inputs[iop](temp_indices);
        }
      }
      sum += tmp;
      /* Increment the innermost reduction index and propagate carries. */
      ridx.Set(ridx.size() - 1, ridx[ridx.size() - 1] + 1);
      for (int i = static_cast<int>(ridx.size() - 1);
           (i > 0) && GetConstInt(ridx[i] >= reduceshape[i]); --i) {
        ridx.Set(i, ridx[i] - reduceshape[i]);
        ridx.Set(i - 1, ridx[i - 1] + 1);
      }
      rec_flag = GetConstInt(ridx[0] < reduceshape[0]);
    } while (rec_flag);
    return sum;
  };

  return compute(oshape, func, name, tag);
}
+
+} // namespace topi
+} // namespace tvm
+#endif // TVM_TOPI_EINSUM_H_
diff --git a/include/tvm/topi/tags.h b/include/tvm/topi/tags.h
index 3b748ca..c3641ae 100644
--- a/include/tvm/topi/tags.h
+++ b/include/tvm/topi/tags.h
@@ -41,6 +41,7 @@ constexpr auto kDepthwiseConv2dNCHW = "depthwise_conv2d_nchw";
constexpr auto kDepthwiseConv2dNHWC = "depthwise_conv2d_nhwc";
constexpr auto kDepthwiseConv2dBackInputNHWC = "depthwise_conv2d_back_input_nhwc";
constexpr auto kDepthwiseConv2dBackWeightNHWC = "depthwise_conv2d_back_weight_nhwc";
+constexpr auto kEinsum = "einsum";
constexpr auto kGroupConv2d = "group_conv2d";
inline bool is_broadcast(std::string tag) {
diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py
index 873901d..6836f04 100644
--- a/python/tvm/topi/__init__.py
+++ b/python/tvm/topi/__init__.py
@@ -41,6 +41,7 @@ from .scatter import *
from .scatter_add import *
from .argwhere import *
from .cumsum import *
+from .einsum import *
from . import generic
from . import nn
from . import x86
diff --git a/python/tvm/topi/einsum.py b/python/tvm/topi/einsum.py
new file mode 100644
index 0000000..f1f426e
--- /dev/null
+++ b/python/tvm/topi/einsum.py
@@ -0,0 +1,44 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,consider-using-enumerate,redefined-outer-name
+"""Einsum operator"""
+from . import cpp
+
+
+def einsum(subscripts, *operand):
+ """Evaluates the Einstein summation convention on the operands.
+
+ Parameters
+ ----------
+ subscripts : string
+ Specifies the subscripts for summation as comma separated list of subscript labels.
+ An implicit (classical Einstein summation) calculation is performed unless the
+ explicit indicator '->' is included as well as subscript labels of the precise
+ output form.
+
+ operand : sequence of tvm.te.Tensor
+ The tensors for the operation, given as separate positional arguments,
+ for example topi.einsum("ij, jk -> ik", A, B).
+ NOTE(review): an earlier docstring said the tensors must be wrapped in an
+ extra tuple, but the *operand signature collects positional arguments
+ itself and the tests pass the tensors individually.
+
+ Returns
+ -------
+ out : tvm.te.Tensor
+ The calculation based on the Einstein summation convention.
+ """
+
+ return cpp.einsum(subscripts, operand)
diff --git a/src/topi/transform.cc b/src/topi/transform.cc
index e1e3988..f71fae3 100644
--- a/src/topi/transform.cc
+++ b/src/topi/transform.cc
@@ -23,6 +23,7 @@
*/
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
+#include <tvm/topi/einsum.h>
#include <tvm/topi/transform.h>
#include <tvm/topi/utils.h>
@@ -165,6 +166,10 @@ TVM_REGISTER_GLOBAL("topi.tensordot").set_body([](TVMArgs args, TVMRetValue* rv)
}
});
+// FFI entry point for topi.einsum: args[0] is the subscripts string,
+// args[1] the collection of input tensors forwarded from the Python wrapper.
+TVM_REGISTER_GLOBAL("topi.einsum").set_body([](TVMArgs args, TVMRetValue* rv) {
+ *rv = einsum(args[0], args[1]);
+});
+
TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = strided_slice(args[0], args[1], args[2], args[3], args[4]);
});
diff --git a/tests/python/topi/python/test_topi_einsum.py b/tests/python/topi/python/test_topi_einsum.py
new file mode 100644
index 0000000..49e9513
--- /dev/null
+++ b/tests/python/topi/python/test_topi_einsum.py
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import tvm
+import tvm.testing
+from tvm import te
+from tvm import topi
+from tvm.topi.utils import get_const_tuple
+
+
+def with_tvm(lam, *args):
+ """Take numpy arrays as args, convert them to TVM tensors and call `lam`.
+ Result of lambda is converted back to numpy array and returned.
+ """
+ ctx = tvm.cpu(0)
+ pls = [] # placeholders
+ vals_nd = [] # initial values
+ for i, arg in enumerate(args):
+ pls.append(te.placeholder(arg.shape, name="pl" + str(i)))
+ vals_nd.append(tvm.nd.array(arg, ctx))
+
+ out = lam(*pls)
+ out_nd = tvm.nd.array(np.zeros(get_const_tuple(out.shape),
dtype=out.dtype), ctx)
+ s = te.create_schedule([out.op])
+ m = tvm.build(s, pls + [out], "llvm")
+ m(*(vals_nd + [out_nd]))
+ return out_nd.asnumpy()
+
+
+def verify_einsum(subscripts, shapes):
+ ops = []
+ for shape in shapes:
+ tmp = np.random.uniform(low=-1.0, high=1.0,
size=shape).astype(np.float32)
+ ops.append(tmp)
+
+ c1 = np.einsum(subscripts, *ops)
+
+ if len(ops) == 1:
+ c2 = with_tvm(lambda A: topi.einsum(subscripts, A), *ops)
+ elif len(ops) == 2:
+ c2 = with_tvm(lambda A, B: topi.einsum(subscripts, A, B), *ops)
+ elif len(ops) == 3:
+ c2 = with_tvm(lambda A, B, C: topi.einsum(subscripts, A, B, C), *ops)
+
+ tvm.testing.assert_allclose(c1, c2, rtol=1e-5, atol=1e-5)
+
+
+def test_einsum():
+ verify_einsum("ii", [(5, 5)])
+ verify_einsum("ii->i", [(5, 5)])
+ verify_einsum("ij->i", [(5, 5)])
+ verify_einsum("...j->...", [(5, 5)])
+ verify_einsum("...j, j", [(5, 5), (5,)])
+ verify_einsum("..., ...", [(), (2, 3)])
+ verify_einsum("ijk, jil->kl", [(3, 4, 5), (4, 3, 2)])
+ verify_einsum("ij, ij -> i", [(1, 4), (2, 4)])
+ verify_einsum("...ij, ...jk -> ...ik", [(1, 4), (4, 2)])
+ verify_einsum("...ij, ...ik -> ...jk", [(1, 1, 1, 4), (1, 1, 1, 3)])
+ verify_einsum("ij,jk->ik", [(2, 3), (3, 4)])
+ verify_einsum("ij,jk,km->im", [(2, 3), (3, 4), (4, 5)])
+
+
+if __name__ == "__main__":
+ test_einsum()
diff --git a/tests/python/unittest/test_te_autodiff.py b/tests/python/unittest/test_te_autodiff.py
index 6031182..b2f2647 100644
--- a/tests/python/unittest/test_te_autodiff.py
+++ b/tests/python/unittest/test_te_autodiff.py
@@ -170,6 +170,10 @@ def test_basic_operation():
Y = topi.tensordot(A, B, 1)
check_grad(Y, X)
+ X = te.placeholder((3, 3), name="X")
+ Y = topi.einsum("ii->i", (X))
+ check_grad(Y, X)
+
def test_topi():
X = te.placeholder((1, 2, 4, 4), name="X")