tqchen commented on a change in pull request #6917:
URL: https://github.com/apache/tvm/pull/6917#discussion_r529696620
##
File path: src/target/source/codegen_params.cc
##
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file codegen_params.cc
+ */
+
+#include "codegen_params.h"
+
+#include
+
+#include
+#include
+#include
+#include
+
+namespace tvm {
+namespace codegen {
+
+namespace {
+class DLManagedTensorDeleter {
+ public:
+ void operator()(DLManagedTensor* ptr) { ptr->deleter(ptr); }
+};
+} // namespace
+
/*! \brief Target upper bound, in characters, on the width of each generated line of array data. */
constexpr int kMaxLineLength = 80;
+
/*!
 * \brief Print an array of integers as comma-separated hexadecimal C literals.
 *
 * Every value is zero-padded to the full width of T (two hex digits per byte).
 * Signed types get an explicit sign ("-0x05", "+0x0a"). 64-bit types get an
 * "LL" (signed) or "ULL" (unsigned) suffix, matching the per-element width
 * budget computed by NDArrayDataToC.
 *
 * \param data Pointer to at least num_elements values of type T.
 * \param num_elements Number of values to print.
 * \param elements_per_row Values per output line; values below 1 are treated as 1.
 * \param indent_str Text written after each line break.
 * \param os Destination stream. Its formatting flags and fill character are
 *        left unchanged.
 */
template <typename T, typename = std::enable_if_t<std::is_integral<T>::value>>
void PrintArray(void* data, size_t num_elements, int elements_per_row,
                const std::string& indent_str, std::ostream& os) {
  // A row length of 0 would make the modulo below divide by zero.
  const size_t row_length = elements_per_row < 1 ? 1 : static_cast<size_t>(elements_per_row);
  const int hex_width = static_cast<int>(sizeof(T) * 2);
  const bool is_64_bit = sizeof(T) > 4;

  // Format the digits in a private stream so the literals are always
  // zero-padded hex, whatever base/fill the caller configured on `os`.
  std::stringstream digits;
  digits << std::hex << std::setfill('0');

  const T* values = static_cast<const T*>(data);
  for (size_t i = 0; i < num_elements; ++i) {
    const T value = values[i];
    uint64_t magnitude;
    if (std::is_signed<T>::value) {
      const int64_t signed_value = static_cast<int64_t>(value);
      if (signed_value < 0) {
        os << "-";
        // Negate in unsigned arithmetic: `-signed_value` overflows for INT64_MIN.
        magnitude = uint64_t{0} - static_cast<uint64_t>(signed_value);
      } else {
        os << "+";
        magnitude = static_cast<uint64_t>(signed_value);
      }
    } else {
      // Convert directly; do not round-trip large unsigned values through int64_t.
      magnitude = static_cast<uint64_t>(value);
    }

    digits.str("");
    digits << std::setw(hex_width) << magnitude;
    os << "0x" << digits.str();
    if (is_64_bit) {
      os << (std::is_signed<T>::value ? "LL" : "ULL");
    }

    if (i + 1 < num_elements) {
      os << ", ";
    }
    if ((i + 1) % row_length == 0) {
      os << "\n" << indent_str;
    }
  }
}
+
/*!
 * \brief Print an array of floating-point values as comma-separated C literals.
 *
 * Finite values are written as hexadecimal floating-point literals (printf
 * "%a" form, e.g. "0x1.8p+0"), which reproduce the stored value exactly.
 * Infinities are written as the C99 <math.h> macro INFINITY, with a leading
 * "-" (negative) or " " (positive) so both occupy the same width. NaNs are
 * written as the C99 <math.h> macro NAN, which is defined when the
 * implementation supports quiet NaNs. Each value is right-aligned in a field
 * of one_element_size_bytes characters.
 *
 * \param data Pointer to at least num_elements values of type T.
 * \param num_elements Number of values to print.
 * \param one_element_size_bytes Field width used to right-align each value.
 * \param elements_per_row Values per output line; values below 1 are treated as 1.
 * \param indent_str Text written after each line break.
 * \param os Destination stream.
 */
template <typename T, typename = std::enable_if_t<std::is_floating_point<T>::value>>
void PrintArray(void* data, size_t num_elements, int one_element_size_bytes,
                int elements_per_row, const std::string& indent_str, std::ostream& os) {
  // A row length of 0 would make the modulo below divide by zero.
  const size_t row_length = elements_per_row < 1 ? 1 : static_cast<size_t>(elements_per_row);

  // std::hexfloat (floatfield == fixed|scientific) selects the "%a" format.
  // Integer base/showbase flags do not affect floating-point output, so no
  // other flags are needed.
  std::stringstream literal;
  literal << std::hexfloat;

  const T* values = static_cast<const T*>(data);
  for (size_t i = 0; i < num_elements; ++i) {
    const T elem = values[i];
    if (std::isinf(elem)) {
      // C99 <math.h> macro; the sign character takes one column of the field.
      os << (elem < 0 ? "-" : " ") << std::setw(one_element_size_bytes - 1) << "INFINITY";
    } else if (std::isnan(elem)) {
      // C99 <math.h> macro (available when quiet NaNs are supported).
      os << std::setw(one_element_size_bytes) << "NAN";
    } else {
      literal.str("");
      literal << elem;
      os << std::setw(one_element_size_bytes) << literal.str();
    }

    if (i + 1 < num_elements) {
      os << ", ";
    }
    if ((i + 1) % row_length == 0) {
      os << "\n" << indent_str;
    }
  }
}
+
+void NDArrayDataToC(::tvm::runtime::NDArray arr, int indent_chars,
std::ostream& os) {
+ auto arr_type = arr.DataType();
+ CHECK_EQ(arr_type.lanes(), 1) << "CodegenParams: only support generating
1-lane parameters; saw "
+<< arr_type.lanes();
+
+ int one_element_size_bytes = (arr_type.bits() / 4) + (2 /* "0x" */) + (2 /*
", " */);
+ if (arr_type.code() == runtime::DataType::TypeCode::kInt) {
+one_element_size_bytes += 1; // sign character
+if (arr_type.bits() > 32) {
+ one_element_size_bytes += 2; // "LL"
+}
+ } else if (arr_type.code() == runtime::DataType::TypeCode::kUInt) {
+if (arr_type.bits() > 32) {
+ one_element_size_bytes += 3; // "ULL"
+}
+ } else if (arr_type.code() == runtime::DataType::TypeCode::kFloat) {
+// Floats and doubles are printed as hex but casted.
+one_element_size_bytes += 1 /* sign */ + 1 /* decimal point */ + 1 /*
exponent sign */;
+if (arr_type.bits() == 64) {
+ one_element_size_bytes += 2; /* 4 decimal digits in exponent, relative
to bits / 4 */
+} else if (arr_type.bits() == 32) {
+ one_element_size_bytes += 1; /* extra decimal digit in exponent,
relative to bits / 4 */
+}
+ }
+
+ int elements_per_row = 16;
+ while (elements_per_row > 1 &&
+ (elements_per_row * one_element_size_bytes) > (kMaxLineLength -
indent_chars)) {
+elements_per_row /= 2;
+ }
+
+ std::string indent_str(indent_chars, ' ');
+ os << indent_str;
+
+ auto shape = arr.Shape();
+ int num_elements = 1;
+ for (auto shape_elem : shape) {
+num_elements *=