vinx13 commented on a change in pull request #9306: URL: https://github.com/apache/tvm/pull/9306#discussion_r734149662
########## File path: tests/python/unittest/test_tvmscript_error_report.py ########## @@ -511,5 +512,77 @@ def render(e): # TODO(Siyuan): block iter errors. + +@T.prim_func +def elementwise_not_affine(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128, 128)) + for i, j, k, l in T.grid(128, 128, 128, 8): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + vl = T.axis.S(128, l * 16) + B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 + + +@T.prim_func +def elementwise_non_single_branch(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128)) + C = T.alloc_buffer((128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128)) + for i, j in T.grid(128, 128): + for k in T.serial(0, 128): + with T.block("C"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + C[vi, vj, vk] = A[vi, vj, vk] * 2.0 + for k in T.serial(0, 128): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + B[vi, vj, vk] = C[vi, vj, vk] * 2.0 + + +def test_reorder_fail_block(): + sch = tir.Schedule(elementwise_not_affine, debug_mask="all") + block_b = sch.get_block("B") + i, j, k, l = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError) as execinfo: + sch.reorder(l, i) + expected_sub_error_message = ( + " # tir.Block#0\n" + ' with tir.block("B"):\n' + " ^^^^^^^^^^^^^^^^^^^^\n" + ) + assert expected_sub_error_message in str(execinfo.value) + + +def test_reorder_fail_nested_loop_inner(): + sch = tir.Schedule(elementwise_non_single_branch, debug_mask="all") + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError) as execinfo: + sch.reorder(k, i) + expected_sub_error_message = ( + " for i in tir.serial(0, 128):\n" + " # tir.For#0\n" + " for j in tir.serial(0, 128):\n" + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n" + ) + assert expected_sub_error_message in str(execinfo.value) + + +def test_reorder_fail_nested_loop_outer(): Review 
comment: ```suggestion def test_fuse_fail_nested_loop_outer(): ``` ########## File path: src/printer/tvmscript_printer.cc ########## @@ -1343,12 +1371,60 @@ Doc TVMScriptPrinter::PrintLoopStack() { return res; } +/*! + * \brief The printer for TVMScript with diagnostic + * \details The printer obtain the precedence of the top-level operation when printing each + * subexpression to decide whether or not parentheses is needed. + */ +class TVMScriptPrinterWithDiagnostic : public TVMScriptPrinter { + public: + explicit TVMScriptPrinterWithDiagnostic(const String& tir_prefix, bool show_meta, + runtime::TypedPackedFunc<std::string(Stmt)> annotate) + : TVMScriptPrinter(tir_prefix, show_meta, annotate) {} + + protected: + Doc PrintBlockName(const BlockNode* block_op); + Doc PrintUnderline(const Stmt& stmt, int length); + Doc PrintLoop(const For& loop); +}; + +Doc TVMScriptPrinterWithDiagnostic::PrintBlockName(const BlockNode* block_op) { + Doc doc = PrintOptionalInfo(GetRef<Stmt>(block_op)); Review comment: better to move this to `TVMScriptPrinter` ########## File path: src/tir/schedule/error.cc ########## @@ -24,29 +24,37 @@ namespace tir { String ScheduleError::RenderReport(const String& primitive) const { IRModule mod = this->mod(); std::ostringstream os; - os << "ScheduleError: An error occurred in the schedule primitive '" << primitive - << "'.\n\nThe IR is:\n" - << AsTVMScript(mod); + + // get locations of interest Array<ObjectRef> locs = LocationsOfInterest(); + std::unordered_map<ObjectRef, String, ObjectPtrHash, ObjectPtrEqual> loc_obj_to_name; int n_locs = locs.size(); - std::vector<String> roi_names; - roi_names.reserve(n_locs); - if (n_locs > 0) { - os << "Regions of interest:\n"; - for (const ObjectRef& obj : locs) { - String name = obj->GetTypeKey() + '#' + std::to_string(roi_names.size()); - os << name << "\n" << obj; - roi_names.emplace_back(std::move(name)); - } - os << "\n"; - } std::string msg = DetailRenderTemplate(); - for (int i = 0; i < n_locs; ++i) 
{ - std::string src = "{" + std::to_string(i) + "}"; - for (size_t pos; (pos = msg.find(src)) != std::string::npos;) { - msg.replace(pos, src.length(), roi_names[i]); + if (n_locs > 0) { + for (int i = 0; i < n_locs; ++i) { + std::string name = locs[i]->GetTypeKey() + '#' + std::to_string(i); + std::string src = "{" + std::to_string(i) + "}"; + for (size_t pos; (pos = msg.find(src)) != std::string::npos;) { + msg.replace(pos, src.length(), name); + } + loc_obj_to_name.emplace(locs[i], std::move(name)); } } + + // print IR module + runtime::TypedPackedFunc<std::string(Stmt)> annotate = + runtime::TypedPackedFunc<std::string(Stmt)>( + [&loc_obj_to_name](const Stmt& expr) -> std::string { + auto search = loc_obj_to_name.find(Downcast<ObjectRef>(expr)); + if (search == loc_obj_to_name.end()) return ""; Review comment: nit: (representing iterator) ```suggestion auto it = loc_obj_to_name.find(Downcast<ObjectRef>(expr)); if (it == loc_obj_to_name.end()) return ""; ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
