This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s1 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit af5d18edcad05e18b43813a43934167bdefd7bd3 Author: tqchen <[email protected]> AuthorDate: Sun Apr 13 19:31:49 2025 -0400 lint --- ffi/include/tvm/ffi/any.h | 4 +- ffi/include/tvm/ffi/object.h | 2 +- ffi/include/tvm/ffi/optional.h | 2 +- ffi/include/tvm/ffi/reflection.h | 14 ++--- ffi/src/ffi/object.cc | 7 +-- include/tvm/ir/expr.h | 1 - include/tvm/meta_schedule/schedule_rule.h | 12 ++-- include/tvm/node/repr_printer.h | 2 +- include/tvm/runtime/logging.h | 6 +- include/tvm/runtime/object.h | 2 + include/tvm/runtime/packed_func.h | 7 ++- include/tvm/script/ir_builder/tir/ir.h | 11 ++-- include/tvm/script/printer/doc.h | 1 + include/tvm/target/target_kind.h | 2 +- include/tvm/tir/schedule/schedule.h | 3 +- include/tvm/topi/reduction.h | 14 +++-- include/tvm/topi/transform.h | 9 +-- src/contrib/msc/core/printer/msc_base_printer.cc | 1 + src/meta_schedule/database/database_utils.cc | 5 +- src/meta_schedule/mutator/mutate_tile_size.cc | 4 +- src/meta_schedule/mutator/mutate_unroll.cc | 7 +-- .../schedule_rule/multi_level_tiling.cc | 6 +- .../multi_level_tiling_wide_vector.cc | 8 ++- .../multi_level_tiling_with_intrin.cc | 10 +-- src/meta_schedule/schedule_rule/schedule_rule.cc | 60 +++++++++--------- src/node/repr_printer.cc | 8 +-- src/node/script_printer.cc | 6 +- src/node/serialization.cc | 3 +- src/node/structural_hash.cc | 4 +- src/relax/backend/contrib/clml/codegen.cc | 3 +- src/relax/backend/contrib/tensorrt/codegen.cc | 3 +- src/relax/op/tensor/create.cc | 3 +- src/relax/op/tensor/create.h | 2 +- src/relax/op/tensor/index.cc | 5 +- src/relax/op/tensor/manipulate.h | 2 +- src/relax/transform/lift_transform_params.cc | 10 +-- src/runtime/disco/message_queue.h | 1 + src/runtime/file_utils.cc | 1 + src/runtime/opencl/opencl_device_api.cc | 3 +- src/runtime/relax_vm/ndarray_cache_support.cc | 4 +- src/script/ir_builder/tir/ir.cc | 7 +-- src/script/printer/relax/call.cc | 5 +- src/support/ffi_testing.cc | 9 ++- src/target/codegen.cc | 1 - src/target/tag.cc | 72 +++++++++++----------- src/target/target.cc | 2 +- src/te/operation/create_primfunc.cc | 2 +- src/te/operation/placeholder_op.cc | 2 +- src/tir/schedule/concrete_schedule.h | 3 +- src/tir/schedule/instruction_traits.h | 5 +- src/tir/schedule/ir_comparator.cc | 5 +- src/tir/schedule/primitive.h | 3 +- src/tir/schedule/primitive/sampling.cc | 4 +- src/tir/schedule/traced_schedule.h | 3 +- src/tir/transforms/default_gpu_schedule.cc | 3 +- src/tir/transforms/primfunc_utils.cc | 4 +- src/topi/broadcast.cc | 4 +- .../hexagon/hexagon_device_api_tests.cc | 6 +- .../cpp-runtime/hexagon/hexagon_user_dma_tests.cc | 4 +- tests/cpp/nested_msg_test.cc | 1 - 60 files changed, 198 insertions(+), 205 deletions(-) diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h index a9ae2293f3..ccac4980ea 100644 --- a/ffi/include/tvm/ffi/any.h +++ b/ffi/include/tvm/ffi/any.h @@ -115,7 +115,7 @@ class AnyView { return *std::move(opt); } - /* + /* * \brief Shortcut of as Object to cast to a const pointer when T is an Object. * * \tparam T The object type. @@ -132,7 +132,7 @@ class AnyView { TVM_FFI_INLINE bool operator!=(std::nullptr_t) const noexcept { return data_.type_index != TypeIndex::kTVMFFINone; } - /*! + /*! * \brief Get the type key of the Any * \return The type key of the Any */ diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h index 2237b62073..8a5f535ecb 100644 --- a/ffi/include/tvm/ffi/object.h +++ b/ffi/include/tvm/ffi/object.h @@ -26,8 +26,8 @@ #include <tvm/ffi/base_details.h> #include <tvm/ffi/c_api.h> -#include <string> #include <optional> +#include <string> #include <type_traits> #include <utility> diff --git a/ffi/include/tvm/ffi/optional.h b/ffi/include/tvm/ffi/optional.h index 57c0f5145c..4ac9bc198b 100644 --- a/ffi/include/tvm/ffi/optional.h +++ b/ffi/include/tvm/ffi/optional.h @@ -56,7 +56,7 @@ class Optional<T, std::enable_if_t<!use_ptr_based_optional_v<T>>> { Optional(const Optional<T>& other) : data_(other.data_) {} Optional(Optional<T>&& other) : data_(std::move(other.data_)) {} Optional(std::optional<T> other) : data_(std::move(other)) {} // NOLINT(*) - Optional(std::nullopt_t) {} // NOLINT(*) + Optional(std::nullopt_t) {} // NOLINT(*) // normal value handling. Optional(T other) // NOLINT(*) : data_(std::move(other)) {} diff --git a/ffi/include/tvm/ffi/reflection.h b/ffi/include/tvm/ffi/reflection.h index 230fec9ac0..26e028d726 100644 --- a/ffi/include/tvm/ffi/reflection.h +++ b/ffi/include/tvm/ffi/reflection.h @@ -44,16 +44,12 @@ struct TypeToFieldStaticTypeIndex<T, std::enable_if_t<TypeTraits<T>::convert_ena template <typename T, typename = void> struct TypeToRuntimeTypeIndex { - static int32_t v() { - return TypeToFieldStaticTypeIndex<T>::value; - } + static int32_t v() { return TypeToFieldStaticTypeIndex<T>::value; } }; template <typename T> struct TypeToRuntimeTypeIndex<T, std::enable_if_t<std::is_base_of_v<ObjectRef, T>>> { - static int32_t v() { - return T::ContainerType::RuntimeTypeIndex(); - } + static int32_t v() { return T::ContainerType::RuntimeTypeIndex(); } }; /*! @@ -77,13 +73,13 @@ class ReflectionDef { explicit ReflectionDef(int32_t type_index) : type_index_(type_index) {} template <typename Class, typename T> - ReflectionDef& def_readonly(const char* name, T Class::* field_ptr) { + ReflectionDef& def_readonly(const char* name, T Class::*field_ptr) { RegisterField(name, field_ptr, true); return *this; } template <typename Class, typename T> - ReflectionDef& def_readwrite(const char* name, T Class::* field_ptr) { + ReflectionDef& def_readwrite(const char* name, T Class::*field_ptr) { RegisterField(name, field_ptr, false); return *this; } @@ -92,7 +88,7 @@ class ReflectionDef { private: template <typename Class, typename T> - void RegisterField(const char* name, T Class::* field_ptr, bool readonly) { + void RegisterField(const char* name, T Class::*field_ptr, bool readonly) { TVMFFIFieldInfo info; info.name = name; info.field_static_type_index = TypeToFieldStaticTypeIndex<T>::value; diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc index 0b5933ca60..2b5dc1d83a 100644 --- a/ffi/src/ffi/object.cc +++ b/ffi/src/ffi/object.cc @@ -202,7 +202,7 @@ class TypeTable { if (ptr != nullptr && ptr->type_depth != 0) { int parent_index = ptr->type_acenstors[ptr->type_depth - 1]; num_children[parent_index] += num_children[ptr->type_index] + 1; - if (expected_child_slots[ptr->type_index] + 1< ptr->num_slots) { + if (expected_child_slots[ptr->type_index] + 1 < ptr->num_slots) { expected_child_slots[ptr->type_index] = ptr->num_slots - 1; } expected_child_slots[parent_index] += expected_child_slots[ptr->type_index] + 1; @@ -249,9 +249,8 @@ class TypeTable { ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIOpaquePtr, TypeIndex::kTVMFFIOpaquePtr); ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIDataType, TypeIndex::kTVMFFIDataType); ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIDevice, TypeIndex::kTVMFFIDevice); - ReserveBuiltinTypeIndex( - StaticTypeKey::kTVMFFIObjectRValueRef, - TypeIndex::kTVMFFIObjectRValueRef); + ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIObjectRValueRef, + TypeIndex::kTVMFFIObjectRValueRef); } void ReserveBuiltinTypeIndex(const char* type_key, int32_t static_type_index) { diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index c02be4e1b0..a3defa592a 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -811,7 +811,6 @@ struct TypeTraits<Bool> : public ObjectRefWithFallbackTraitsBase<Bool, int64_t> static TVM_FFI_INLINE Bool ConvertFallbackValue(int64_t value) { return Bool(value != 0); } }; - // define automatic conversion from bool, int64_t, double to PrimExpr TVM_FFI_INLINE PrimExpr TypeTraits<PrimExpr>::ConvertFallbackValue(StrictBool value) { return IntImm(DataType::Bool(), value, Span()); diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index 6bb7176192..b4da978b56 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -154,10 +154,10 @@ class ScheduleRule : public runtime::ObjectRef { * ignored by default. This function should return True for a block that should be tiled. * \return The schedule rule created */ - TVM_DLL static ScheduleRule MultiLevelTiling(String structure, // - Optional<Array<String>> tile_binds, // - Optional<Integer> max_innermost_factor, // - Optional<Array<Integer>> vector_load_lens, // + TVM_DLL static ScheduleRule MultiLevelTiling(String structure, // + Optional<Array<String>> tile_binds, // + Optional<Integer> max_innermost_factor, // + Optional<Array<Integer>> vector_load_lens, // Optional<Map<String, ffi::Any>> reuse_read, // Optional<Map<String, ffi::Any>> reuse_write, Optional<runtime::PackedFunc> filter_fn = NullOpt); @@ -260,8 +260,8 @@ class ScheduleRule : public runtime::ObjectRef { * \param unroll_explicit Whether to explicitly unroll the loop, or just add an "unroll" pragma. * \return The schedule rule created */ - TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, // - int max_vectorize_extent, // + TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, // + int max_vectorize_extent, // Array<Integer> unroll_max_steps, // bool unroll_explicit); /*! diff --git a/include/tvm/node/repr_printer.h b/include/tvm/node/repr_printer.h index 208dfc6347..30bfe8e951 100644 --- a/include/tvm/node/repr_printer.h +++ b/include/tvm/node/repr_printer.h @@ -43,7 +43,7 @@ class ReprPrinter { /*! \brief The node to be printed. */ TVM_DLL void Print(const ObjectRef& node); - /*! \brief The node to be printed. */ + /*! \brief The node to be printed. */ TVM_DLL void Print(const ffi::Any& node); /*! \brief Print indent to the stream */ TVM_DLL void PrintIndent(); diff --git a/include/tvm/runtime/logging.h b/include/tvm/runtime/logging.h index 5d8780256c..807c9dbf30 100644 --- a/include/tvm/runtime/logging.h +++ b/include/tvm/runtime/logging.h @@ -216,7 +216,8 @@ class InternalError : public Error { if (pos != std::string::npos) { size_t end = pos + 6; size_t begin = pos; - for (; begin != 0 && message[begin - 1] != ' '; --begin) {} + for (; begin != 0 && message[begin - 1] != ' '; --begin) { + } return message.substr(begin, end - begin - 1); } else { return "InternalError"; @@ -228,7 +229,8 @@ class InternalError : public Error { if (pos != std::string::npos) { size_t end = pos + 6; size_t begin = pos; - for (; begin != 0 && message[begin - 1] != ' '; --begin) {} + for (; begin != 0 && message[begin - 1] != ' '; --begin) { + } if (end < message.size() && message[end] == ' ') { end += 1; } diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 61296d80b9..dd7bdacd63 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -29,6 +29,8 @@ #include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/container/optional.h> +#include <utility> + namespace tvm { namespace runtime { diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 826df2301b..d90017288b 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -32,6 +32,10 @@ #include <tvm/runtime/module.h> #include <tvm/runtime/ndarray.h> +#include <string> +#include <utility> +#include <vector> + namespace tvm { namespace runtime { @@ -474,10 +478,9 @@ struct ModuleVTableEntryHelper<void (T::*)(Args...)> { } \ } \ } -} // namespace runtime +} // namespace runtime // NOLINT(*) using ffi::Any; using ffi::AnyView; - } // namespace tvm #endif // TVM_RUNTIME_PACKED_FUNC_H_ diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index ad86e859d8..5927a0a284 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -150,7 +150,7 @@ void Writes(Array<ObjectRef> buffer_slices); * \brief The block annotation statement. * \param attrs The annotation of the block. */ -void BlockAttrs(Map<String,ffi::Any> attrs); +void BlockAttrs(Map<String, ffi::Any> attrs); /*! * \brief The buffer allocation function. @@ -227,8 +227,7 @@ Array<Var> Remap(String kinds, Array<PrimExpr> bindings, DataType dtype = DataTy * \param annotations The optional annotations of the For statement. * \return The ForFrame. */ -ForFrame Serial(PrimExpr start, PrimExpr stop, - Optional<Map<String, Any>> annotations = NullOpt); +ForFrame Serial(PrimExpr start, PrimExpr stop, Optional<Map<String, Any>> annotations = NullOpt); /*! * \brief The parallel For statement. * \param start The minimum value of iteration. @@ -236,8 +235,7 @@ ForFrame Serial(PrimExpr start, PrimExpr stop, * \param annotations The optional annotations of the For statement. * \return The ForFrame. */ -ForFrame Parallel(PrimExpr start, PrimExpr stop, - Optional<Map<String, Any>> annotations = NullOpt); +ForFrame Parallel(PrimExpr start, PrimExpr stop, Optional<Map<String, Any>> annotations = NullOpt); /*! * \brief The vectorized For statement. * \param start The minimum value of iteration. @@ -254,8 +252,7 @@ ForFrame Vectorized(PrimExpr start, PrimExpr stop, * \param annotations The optional annotations of the For statement. * \return The ForFrame. */ -ForFrame Unroll(PrimExpr start, PrimExpr stop, - Optional<Map<String, Any>> annotations = NullOpt); +ForFrame Unroll(PrimExpr start, PrimExpr stop, Optional<Map<String, Any>> annotations = NullOpt); /*! * \brief The thread-binding For statement. * \param start The minimum value of iteration. diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index 8b8e180cc7..18d7a8194e 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -23,6 +23,7 @@ #include <tvm/node/node.h> #include <tvm/runtime/data_type.h> #include <tvm/runtime/device_api.h> + #include <string> namespace tvm { diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index 8856b14604..2ec3fec589 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -414,7 +414,7 @@ inline TargetKindRegEntry& TargetKindRegEntry::set_name() { .add_attr_option<String>("model") \ .add_attr_option<Array<String>>("libs") \ .add_attr_option<Target>("host") \ - .add_attr_option<int64_t>("from_device") \ + .add_attr_option<int64_t>("from_device") \ .add_attr_option<int64_t>("target_device_type") } // namespace tvm diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 789fc00fd9..f869a4840c 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -224,8 +224,7 @@ class ScheduleNode : public runtime::Object { * \param decision The sampling decision * \return The random variable sampled from candidates */ - virtual ExprRV SampleCategorical(const Array<Integer>& candidates, - const Array<FloatImm>& probs, + virtual ExprRV SampleCategorical(const Array<Integer>& candidates, const Array<FloatImm>& probs, Optional<Integer> decision = NullOpt) = 0; /*! * \brief Sample the factors to perfect tile a specific loop diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h index 8ed36f6f76..277de68e97 100644 --- a/include/tvm/topi/reduction.h +++ b/include/tvm/topi/reduction.h @@ -202,8 +202,8 @@ inline Tensor CommReduce(const Tensor& data, const Optional<Array<Integer>>& axi * * \return The result tensor. */ -inline Tensor CommReduceIdx(const Tensor& data, const Optional<Array<Integer>>& axis, FCommReduce func, - bool keepdims, bool atleast1d) { +inline Tensor CommReduceIdx(const Tensor& data, const Optional<Array<Integer>>& axis, + FCommReduce func, bool keepdims, bool atleast1d) { auto ndim = data->shape.size(); ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; auto real_axis = GetRealAxis(static_cast<int>(ndim), axis); @@ -497,8 +497,9 @@ inline FCommReduce MakeArgminReducer(bool select_last_index = false) { * * \return A Tensor whose op member is the argmin operation */ -inline Tensor argmin(const Tensor& data, const Optional<Array<Integer>>& axis, bool keepdims = false, - bool atleast1d = false, bool select_last_index = false) { +inline Tensor argmin(const Tensor& data, const Optional<Array<Integer>>& axis, + bool keepdims = false, bool atleast1d = false, + bool select_last_index = false) { auto reducer = MakeArgminReducer(select_last_index); return CommReduceIdx(data, axis, reducer, keepdims, atleast1d); } @@ -557,8 +558,9 @@ inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) { * appears multiple times, else select the first index. * \return A Tensor whose op member is the argmax operation */ -inline Tensor argmax(const Tensor& data, const Optional<Array<Integer>>& axis, bool keepdims = false, - bool atleast1d = false, bool select_last_index = false) { +inline Tensor argmax(const Tensor& data, const Optional<Array<Integer>>& axis, + bool keepdims = false, bool atleast1d = false, + bool select_last_index = false) { auto reducer = MakeArgmaxReducer(select_last_index); return CommReduceIdx(data, axis, reducer, keepdims, atleast1d); } diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 07e830e78a..59021a18db 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -42,6 +42,7 @@ #include <string> #include <unordered_set> #include <vector> +#include <utility> #include "tvm/ir/expr.h" #include "tvm/runtime/data_type.h" @@ -1756,10 +1757,10 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout, Array<PrimExpr> dst_shape = layout_converter.ForwardShape(src->shape); Map<String, ffi::Any> attrs = {{"schedule_rule", String(schedule_rule)}, - // Information about layouts needed for the schedule rule - {"src_layout", String(src_layout)}, - {"dst_layout", String(dst_layout)}, - {"input_shape", src->shape}}; + // Information about layouts needed for the schedule rule + {"src_layout", String(src_layout)}, + {"dst_layout", String(dst_layout)}, + {"input_shape", src->shape}}; return compute( dst_shape, diff --git a/src/contrib/msc/core/printer/msc_base_printer.cc b/src/contrib/msc/core/printer/msc_base_printer.cc index e4fde2a130..954d5246d9 100644 --- a/src/contrib/msc/core/printer/msc_base_printer.cc +++ b/src/contrib/msc/core/printer/msc_base_printer.cc @@ -21,6 +21,7 @@ * \file src/contrib/msc/core/printer/msc_base_printer.cc */ +#include <utility> #include "msc_base_printer.h" #include "../utils.h" diff --git a/src/meta_schedule/database/database_utils.cc b/src/meta_schedule/database/database_utils.cc index fdc100cecb..e24fd15551 100644 --- a/src/meta_schedule/database/database_utils.cc +++ b/src/meta_schedule/database/database_utils.cc @@ -67,9 +67,8 @@ void JSONDumps(Any json_obj, std::ostringstream& os) { << kv.first.GetTypeKey(); } } - std::sort(key_values.begin(), key_values.end(), [](const auto& a, const auto& b) { - return a.first < b.first; - }); + std::sort(key_values.begin(), key_values.end(), + [](const auto& a, const auto& b) { return a.first < b.first; }); os << "{"; for (int i = 0; i < n; ++i) { const auto& kv = key_values[i]; diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc index 4376c4db07..201dee69da 100644 --- a/src/meta_schedule/mutator/mutate_tile_size.cc +++ b/src/meta_schedule/mutator/mutate_tile_size.cc @@ -129,8 +129,8 @@ void FindSampleVectorize(const Trace& trace, std::vector<Instruction>* inst, ICHECK_EQ(inst->outputs.size(), 1); if (annotated.count(inst->outputs[0].as<Object>())) { ICHECK_EQ(inst->attrs.size(), 2); - std::vector<double> probs = support::AsVector<FloatImm, double>( - Downcast<Array<FloatImm>>(inst->attrs[1])); + std::vector<double> probs = + support::AsVector<FloatImm, double>(Downcast<Array<FloatImm>>(inst->attrs[1])); if (probs.size() == 1) { // Skip mutating the sampling instructions who have only single candidate. continue; diff --git a/src/meta_schedule/mutator/mutate_unroll.cc b/src/meta_schedule/mutator/mutate_unroll.cc index 851cd7e75d..a3560943db 100644 --- a/src/meta_schedule/mutator/mutate_unroll.cc +++ b/src/meta_schedule/mutator/mutate_unroll.cc @@ -113,10 +113,9 @@ bool FindUnrollDecision(const Trace& trace, TRandState* rand_state, const InstructionNode* sample_inst = sample_insts.at(var_rv); ICHECK_EQ(sample_inst->attrs.size(), 2); candidate->inst = GetRef<Instruction>(sample_inst); - candidate->decision = - Downcast<IntImm>(trace->decisions[GetRef<Instruction>(sample_inst)])->value; - candidate->probs = support::AsVector<FloatImm, double>( - Downcast<Array<FloatImm>>(sample_inst->attrs[1])); + candidate->decision = Downcast<IntImm>(trace->decisions[GetRef<Instruction>(sample_inst)])->value; + candidate->probs = + support::AsVector<FloatImm, double>(Downcast<Array<FloatImm>>(sample_inst->attrs[1])); return true; } diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index d694bb986f..b1709cf9a6 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -383,9 +383,9 @@ void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch, if (!valid_vector_lens.empty()) { int n = valid_vector_lens.size(); double prob = 1.0 / n; - tir::ExprRV vector_load_len = (*sch)->SampleCategorical( - support::AsArray<int, Integer>(valid_vector_lens), - Array<FloatImm>(n, FloatImm(DataType::Float(32), prob))); + tir::ExprRV vector_load_len = + (*sch)->SampleCategorical(support::AsArray<int, Integer>(valid_vector_lens), + Array<FloatImm>(n, FloatImm(DataType::Float(32), prob))); (*sch)->Annotate(block, tir::attr::meta_schedule_cooperative_fetch, vector_load_len); } } diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc index 95783541d5..deceaa6f2c 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc @@ -112,9 +112,11 @@ std::pair<Array<tir::ExprRV>, Array<tir::LoopRV>> MultiLevelTilingWideVectorNode } } -ScheduleRule ScheduleRule::MultiLevelTilingWideVector( - String structure, Integer vector_length_in_bits, Optional<Integer> max_innermost_factor, - Optional<Map<String, ffi::Any>> reuse_read, Optional<Map<String, ffi::Any>> reuse_write) { +ScheduleRule ScheduleRule::MultiLevelTilingWideVector(String structure, + Integer vector_length_in_bits, + Optional<Integer> max_innermost_factor, + Optional<Map<String, ffi::Any>> reuse_read, + Optional<Map<String, ffi::Any>> reuse_write) { auto node = MultiLevelTilingInitCommon<MultiLevelTilingWideVectorNode>( structure, NullOpt, max_innermost_factor, NullOpt, reuse_read, reuse_write); node->vector_length_in_bits = vector_length_in_bits->value; diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc index f1437f5b35..a8563750c8 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc @@ -91,10 +91,12 @@ class MultiLevelTilingWithIntrinNode : public MultiLevelTilingNode { TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingWithIntrinNode, MultiLevelTilingNode); }; -ScheduleRule ScheduleRule::MultiLevelTilingWithIntrin( - String intrin_name, String structure, Optional<Array<String>> tile_binds, - Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens, - Optional<Map<String, ffi::Any>> reuse_read, Optional<Map<String, ffi::Any>> reuse_write) { +ScheduleRule ScheduleRule::MultiLevelTilingWithIntrin(String intrin_name, String structure, + Optional<Array<String>> tile_binds, + Optional<Integer> max_innermost_factor, + Optional<Array<Integer>> vector_load_lens, + Optional<Map<String, ffi::Any>> reuse_read, + Optional<Map<String, ffi::Any>> reuse_write) { ICHECK(tir::TensorIntrin::Get(intrin_name).defined()) << "Provided tensor intrinsic " << intrin_name << " is not registered."; auto node = MultiLevelTilingInitCommon<MultiLevelTilingWithIntrinNode>( diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc index 1df24a601d..d5bbe8207a 100644 --- a/src/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -74,8 +74,8 @@ Array<ScheduleRule> ScheduleRule::DefaultLLVM() { /*reuse_read=*/NullOpt, /*reuse_write=*/ Map<String, ffi::Any>{{"req", String("may")}, - {"levels", Array<Integer>{1, 2}}, - {"scope", String("global")}}), + {"levels", Array<Integer>{1, 2}}, + {"scope", String("global")}}), ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/16, /*max_vectorize_extent=*/64, @@ -111,8 +111,8 @@ Array<ScheduleRule> ScheduleRule::DefaultX86(const String& type) { /*reuse_read=*/NullOpt, /*reuse_write=*/ Map<String, ffi::Any>{{"req", String("may")}, - {"levels", Array<Integer>{1, 2}}, - {"scope", String("global")}}), + {"levels", Array<Integer>{1, 2}}, + {"scope", String("global")}}), ScheduleRule::MultiLevelTiling( /*structure=*/"SSRSRS", /*tile_binds=*/NullOpt, @@ -121,8 +121,8 @@ Array<ScheduleRule> ScheduleRule::DefaultX86(const String& type) { /*reuse_read=*/NullOpt, /*reuse_write=*/ Map<String, ffi::Any>{{"req", String("may")}, - {"levels", Array<Integer>{1, 2}}, - {"scope", String("global")}}), + {"levels", Array<Integer>{1, 2}}, + {"scope", String("global")}}), ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/16, /*max_vectorize_extent=*/64, @@ -142,12 +142,12 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDA() { /*vector_load_lens=*/Array<Integer>{1, 2, 3, 4, 8, 16}, /*reuse_read=*/ Map<String, ffi::Any>{{"req", String("must")}, - {"levels", Array<Integer>{4}}, // - {"scope", String("shared")}}, + {"levels", Array<Integer>{4}}, // + {"scope", String("shared")}}, /*reuse_write=*/ Map<String, ffi::Any>{{"req", String("must")}, - {"levels", Array<Integer>{3}}, // - {"scope", String("local")}}), + {"levels", Array<Integer>{3}}, // + {"scope", String("local")}}), ScheduleRule::InlineConstantScalars(), ScheduleRule::AutoInline( /*into_producer=*/true, @@ -245,12 +245,12 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDATensorCore() { /*vector_load_lens=*/Array<Integer>{1, 2, 3, 4, 8, 16}, /*reuse_read=*/ Map<String, ffi::Any>{{"req", String("must")}, - {"levels", Array<Integer>{4}}, // - {"scope", String("shared.dyn")}}, + {"levels", Array<Integer>{4}}, // + {"scope", String("shared.dyn")}}, /*reuse_write=*/ Map<String, ffi::Any>{{"req", String("must")}, - {"levels", Array<Integer>{2}}, // - {"scope", String("shared.dyn")}}, + {"levels", Array<Integer>{2}}, // + {"scope", String("shared.dyn")}}, /*use_software_pipeline=*/false), // ScheduleRule::MultiLevelTilingTensorCore( /*intrin_groups=*/mma_intrin_groups, @@ -260,12 +260,12 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDATensorCore() { /*vector_load_lens=*/Array<Integer>{1, 2, 3, 4, 8, 16}, /*reuse_read=*/ Map<String, ffi::Any>{{"req", String("must")}, - {"levels", Array<Integer>{4}}, // - {"scope", String("shared.dyn")}}, + {"levels", Array<Integer>{4}}, // + {"scope", String("shared.dyn")}}, /*reuse_write=*/ Map<String, ffi::Any>{{"req", String("no")}, - {"levels", Array<Integer>{2}}, // - {"scope", String("shared.dyn")}}, + {"levels", Array<Integer>{2}}, // + {"scope", String("shared.dyn")}}, /*use_software_pipeline=*/true) // }; Array<ScheduleRule> append = ScheduleRule::DefaultCUDA(); @@ -292,8 +292,8 @@ Array<ScheduleRule> ScheduleRule::DefaultHexagon() { /*reuse_read=*/NullOpt, /*reuse_write=*/ Map<String, ffi::Any>{{"req", String("may")}, - {"levels", Array<Integer>{1, 2}}, - {"scope", String("global")}}), + {"levels", Array<Integer>{1, 2}}, + {"scope", String("global")}}), ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/16, /*max_vectorize_extent=*/128, @@ -313,8 +313,8 @@ Array<ScheduleRule> GetARMNeonSpecificRules() { /*reuse_read=*/NullOpt, /*reuse_write=*/ Map<String, ffi::Any>{{"req", String("may")}, - {"levels", Array<Integer>{1, 2}}, - {"scope", String("global")}}), + {"levels", Array<Integer>{1, 2}}, + {"scope", String("global")}}), }; } @@ -329,8 +329,8 @@ Array<ScheduleRule> GetARMDotprodSpecificRules() { /*reuse_read=*/NullOpt, /*reuse_write=*/ Map<String, ffi::Any>{{"req", String("may")}, - {"levels", Array<Integer>{1, 2}}, - {"scope", String("global")}}), + {"levels", Array<Integer>{1, 2}}, + {"scope", String("global")}}), ScheduleRule::MultiLevelTilingWithIntrin( /*intrin_name=*/String("dot_4x4_u8u8u32_udot"), /*structure=*/"SSRSRS", @@ -340,8 +340,8 @@ Array<ScheduleRule> GetARMDotprodSpecificRules() { /*reuse_read=*/NullOpt, /*reuse_write=*/ Map<String, ffi::Any>{{"req", String("may")}, - {"levels", Array<Integer>{1, 2}}, - {"scope", String("global")}}), + {"levels", Array<Integer>{1, 2}}, + {"scope", String("global")}}), ScheduleRule::MultiLevelTilingWithIntrin( /*intrin_name=*/String("dot_4x4_u8u8i32_hdot"), /*structure=*/"SSRSRS", @@ -351,8 +351,8 @@ Array<ScheduleRule> GetARMDotprodSpecificRules() { /*reuse_read=*/NullOpt, /*reuse_write=*/ Map<String, ffi::Any>{{"req", String("may")}, - {"levels", Array<Integer>{1, 2}}, - {"scope", String("global")}}), + {"levels", Array<Integer>{1, 2}}, + {"scope", String("global")}}), }; } @@ -380,8 +380,8 @@ Array<ScheduleRule> ScheduleRule::DefaultARM(const String& type) { /*reuse_read=*/NullOpt, /*reuse_write=*/ Map<String, ffi::Any>{{"req", String("may")}, - {"levels", Array<Integer>{1, 2}}, - {"scope", String("global")}}), + {"levels", Array<Integer>{1, 2}}, + {"scope", String("global")}}), ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/8, /*max_vectorize_extent=*/32, diff --git a/src/node/repr_printer.cc b/src/node/repr_printer.cc index 7a0eed7372..22d223fe53 100644 --- a/src/node/repr_printer.cc +++ b/src/node/repr_printer.cc @@ -22,8 +22,8 @@ * \file node/repr_printer.cc */ #include <tvm/node/repr_printer.h> -#include <tvm/runtime/registry.h> #include <tvm/runtime/device_api.h> +#include <tvm/runtime/registry.h> namespace tvm { @@ -62,9 +62,9 @@ void ReprPrinter::Print(const ffi::Any& node) { case ffi::TypeIndex::kTVMFFIOpaquePtr: { stream << node.operator void*(); break; - case ffi::TypeIndex::kTVMFFIDataType: - stream << node.operator DataType(); - break; + case ffi::TypeIndex::kTVMFFIDataType: + stream << node.operator DataType(); + break; } case ffi::TypeIndex::kTVMFFIDevice: { runtime::operator<<(stream, node.operator Device()); diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc index b070f01b96..2f4a9fb9e2 100644 --- a/src/node/script_printer.cc +++ b/src/node/script_printer.cc @@ -96,13 +96,15 @@ PrinterConfig::PrinterConfig(Map<String, Any> config_dict) { n->path_to_underline = Downcast<Optional<Array<ObjectPath>>>(v).value_or(Array<ObjectPath>()); } if (auto v = config_dict.Get("path_to_annotate")) { - n->path_to_annotate = Downcast<Optional<Map<ObjectPath, String>>>(v).value_or(Map<ObjectPath, String>()); + n->path_to_annotate = + Downcast<Optional<Map<ObjectPath, String>>>(v).value_or(Map<ObjectPath, String>()); } if (auto v = config_dict.Get("obj_to_underline")) { n->obj_to_underline = Downcast<Optional<Array<ObjectRef>>>(v).value_or(Array<ObjectRef>()); } if (auto v = config_dict.Get("obj_to_annotate")) { - n->obj_to_annotate = Downcast<Optional<Map<ObjectRef, String>>>(v).value_or(Map<ObjectRef, String>()); + n->obj_to_annotate = + Downcast<Optional<Map<ObjectRef, String>>>(v).value_or(Map<ObjectRef, String>()); } if (auto v = config_dict.Get("syntax_sugar")) { n->syntax_sugar = v.value(); diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 3081b23cee..e98e27ea3a 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -482,8 +482,7 @@ class JSONAttrSetter : public AttrVisitor { static Any CreateInitAny(ReflectionVTable* reflection, JSONNode* jnode) { JSONAttrSetter setter; setter.jnode_ = jnode; - if (jnode->type_key == ffi::StaticTypeKey::kTVMFFINone || - jnode->type_key.empty()) { + if (jnode->type_key == ffi::StaticTypeKey::kTVMFFINone || jnode->type_key.empty()) { // empty key type means None in current implementation return Any(); } diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index 48c92d3117..e81bf66fb9 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -641,8 +641,8 @@ struct MapNodeTrait { // Second, check that we have visited every `rhs` key when iterating over `lhs`. for (const auto& kv : *rhs) { if (!seen_rhs_keys.count(kv.first)) { - equal.RecordMismatchPaths({map_paths->lhs_path->MissingMapEntry(), - map_paths->rhs_path->MapValue(kv.first)}); + equal.RecordMismatchPaths( + {map_paths->lhs_path->MissingMapEntry(), map_paths->rhs_path->MapValue(kv.first)}); return false; } } diff --git a/src/relax/backend/contrib/clml/codegen.cc b/src/relax/backend/contrib/clml/codegen.cc index 16941b9f7c..5fd8a923a1 100644 --- a/src/relax/backend/contrib/clml/codegen.cc +++ b/src/relax/backend/contrib/clml/codegen.cc @@ -305,8 +305,7 @@ void CollectCLMLFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) * \param functions The extern functions to be compiled via OpenCLML * \return Runtime modules. */ -Array<runtime::Module> OpenCLMLCompiler(Array<Function> functions, - Map<String, Any> /*unused*/, +Array<runtime::Module> OpenCLMLCompiler(Array<Function> functions, Map<String, Any> /*unused*/, Map<Constant, String> constant_names) { Array<runtime::Module> compiled_functions; for (const auto& func : functions) { diff --git a/src/relax/backend/contrib/tensorrt/codegen.cc b/src/relax/backend/contrib/tensorrt/codegen.cc index fee19e53d5..12b99b8252 100644 --- a/src/relax/backend/contrib/tensorrt/codegen.cc +++ b/src/relax/backend/contrib/tensorrt/codegen.cc @@ -213,8 +213,7 @@ void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) { * \param functions The extern functions to be compiled via TensorRT * \return Runtime modules. */ -Array<runtime::Module> TensorRTCompiler(Array<Function> functions, - Map<String, ffi::Any> /*unused*/, +Array<runtime::Module> TensorRTCompiler(Array<Function> functions, Map<String, ffi::Any> /*unused*/, Map<Constant, String> constant_names) { Array<runtime::Module> compiled_functions; for (const auto& func : functions) { diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc index 4f502ffe10..5d556cc356 100644 --- a/src/relax/op/tensor/create.cc +++ b/src/relax/op/tensor/create.cc @@ -43,7 +43,8 @@ Expr full(Variant<Expr, Array<PrimExpr>> shape, Expr fill_value, Optional<DataTy } else if (const auto* _array = shape.as<ArrayNode>()) { shape_in_expr = ShapeExpr(GetRef<Array<PrimExpr>>(_array)); } else { - LOG(FATAL) << "Full only expects the input shape to be either an Expr or an Array of PrimExpr. "; + LOG(FATAL) + << "Full only expects the input shape to be either an Expr or an Array of PrimExpr. "; } ObjectPtr<InitAttrs> attrs = make_object<InitAttrs>(); diff --git a/src/relax/op/tensor/create.h b/src/relax/op/tensor/create.h index 47e00c7dc3..5f770b2a6e 100644 --- a/src/relax/op/tensor/create.h +++ b/src/relax/op/tensor/create.h @@ -24,8 +24,8 @@ #ifndef TVM_RELAX_OP_TENSOR_CREATE_H_ #define TVM_RELAX_OP_TENSOR_CREATE_H_ -#include <tvm/runtime/container/variant.h> #include <tvm/relax/attrs/create.h> +#include <tvm/runtime/container/variant.h> #include "../op_common.h" diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc index f54739b6c8..4e62a9ecd7 100644 --- a/src/relax/op/tensor/index.cc +++ b/src/relax/op/tensor/index.cc @@ -95,9 +95,8 @@ StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) { return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); } - int axis = attrs->axis.has_value() - ? NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis.value()) - : 0; + int axis = + attrs->axis.has_value() ? NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis.value()) : 0; const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>(); const auto* indices_shape = indices_sinfo->shape.as<ShapeExprNode>(); if (data_shape == nullptr || indices_shape == nullptr) { diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index a258114e50..fad634ee9b 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -24,8 +24,8 @@ #ifndef TVM_RELAX_OP_TENSOR_MANIPULATE_H_ #define TVM_RELAX_OP_TENSOR_MANIPULATE_H_ -#include <tvm/runtime/container/variant.h> #include <tvm/relax/attrs/manipulate.h> +#include <tvm/runtime/container/variant.h> #include "../op_common.h" #include "tvm/relax/expr.h" diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc index 872459251d..180c0be891 100644 --- a/src/relax/transform/lift_transform_params.cc +++ b/src/relax/transform/lift_transform_params.cc @@ -711,15 +711,15 @@ std::vector<std::pair<GlobalVar, Function>> GetTargetFunctions( auto base_func = mod->functions.Get(gvar.value()); ICHECK(base_func.has_value()) - << "Ill-formed IRModule. " - << "The map from name to GlobalVar found " << gvar.value() - << " for the function name '" << name - << "', but this GlobalVar does not appear in the IRModule"; + << "Ill-formed IRModule. " + << "The map from name to GlobalVar found " << gvar.value() << " for the function name '" + << name << "', but this GlobalVar does not appear in the IRModule"; auto func = base_func.value().as<Function>(); CHECK(func) << "When LiftTransformParams is called with a list of function names, " << "only functions in the list must be relax functions. " - << "However, the function " << name << " is of type " << base_func.value()->GetTypeKey(); + << "However, the function " << name << " is of type " + << base_func.value()->GetTypeKey(); CHECK(func.value()->GetAttr<Integer>(attr::kNumInput)) << "When LiftTransformParams is called with a list of function names, " << "all functions in the list must have the kNumInput ('" << attr::kNumInput diff --git a/src/runtime/disco/message_queue.h b/src/runtime/disco/message_queue.h index 5b53ce40ea..efd3b1735b 100644 --- a/src/runtime/disco/message_queue.h +++ b/src/runtime/disco/message_queue.h @@ -22,6 +22,7 @@ #include <dmlc/io.h> #include <string> +#include <vector> #include "./protocol.h" diff --git a/src/runtime/file_utils.cc b/src/runtime/file_utils.cc index 322b6e7d6d..26974c983d 100644 --- a/src/runtime/file_utils.cc +++ b/src/runtime/file_utils.cc @@ -31,6 +31,7 @@ #include <fstream> #include <unordered_map> #include <vector> +#include <utility> namespace tvm { namespace runtime { diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index f068c4daa7..416a1e2653 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -883,8 +883,7 @@ class OpenCLPooledAllocator final : public memory::PooledAllocator { void* CreateView(const Buffer& buffer, ShapeTuple shape, DLDataType type_hint, const std::string& mem_scope) final { OpenCLWorkspace* ws_ = OpenCLWorkspace::Global(); - return ws_->AllocDataSpaceView(buffer.device, buffer.data, shape, type_hint, - String(mem_scope)); + return ws_->AllocDataSpaceView(buffer.device, buffer.data, shape, type_hint, String(mem_scope)); } void FreeView(Device dev, void* data) final { diff --git a/src/runtime/relax_vm/ndarray_cache_support.cc b/src/runtime/relax_vm/ndarray_cache_support.cc index ceb7c378a3..40e73df0a9 100644 --- a/src/runtime/relax_vm/ndarray_cache_support.cc +++ b/src/runtime/relax_vm/ndarray_cache_support.cc @@ -365,8 +365,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.param_array_from_cache_by_name_unpacked") names.reserve(args.size()); for (int i = 0; i < args.size(); ++i) { if (!args[i].as<String>()) { - LOG(FATAL) << "ValueError: Expect string as input, but get " - << args[i].GetTypeKey() << " at " << i; + LOG(FATAL) << "ValueError: Expect string as input, but get " << args[i].GetTypeKey() + << " at " << i; } names.push_back(args[i]); } diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 1b3de18446..b44d0c29a2 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -325,7 +325,7 @@ Array<Var> Remap(String kinds, Array<PrimExpr> bindings, DataType dtype) { } // namespace axis #define TVM_TIR_IR_BUILDER_FOR_FRAME(Method, Kind) \ - ForFrame Method(PrimExpr start, PrimExpr stop, Optional<Map<String, Any>> annotations) { \ + ForFrame Method(PrimExpr start, PrimExpr stop, Optional<Map<String, Any>> annotations) { \ PrimExpr min = start; \ PrimExpr extent = arith::Analyzer().Simplify(stop - start); \ ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>(); \ @@ -336,7 +336,7 @@ Array<Var> Remap(String kinds, Array<PrimExpr> bindings, DataType dtype) { ICHECK_EQ(vars.size(), 1); \ ICHECK_EQ(doms.size(), 1); \ return tvm::tir::For(vars[0], doms[0]->min, doms[0]->extent, Kind, body, NullOpt, \ - annotations.value_or(Map<String, Any>())); \ + annotations.value_or(Map<String, Any>())); \ }; \ return ForFrame(n); \ } @@ -473,8 +473,7 @@ AllocateFrame Allocate(Array<PrimExpr> extents, DataType dtype, String storage_s } AllocateConstFrame AllocateConst(tvm::runtime::NDArray data, DataType dtype, - Array<PrimExpr> extents, - Optional<Map<String, Any>> annotations) { + Array<PrimExpr> extents, Optional<Map<String, Any>> annotations) { ObjectPtr<AllocateConstFrameNode> n = make_object<AllocateConstFrameNode>(); n->dtype = dtype; n->extents = extents; diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc index e92856d9e8..bd25960912 100644 --- a/src/script/printer/relax/call.cc +++ b/src/script/printer/relax/call.cc @@ -347,9 +347,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) for (const auto& kv : attrs->dict) { sorted.push_back(kv); } - std::sort(sorted.begin(), sorted.end(), [](const auto& a, const auto& b) { - return a.first < b.first; - }); + std::sort(sorted.begin(), sorted.end(), + [](const auto& a, const auto& b) { return a.first < b.first; }); for (const auto& kv : sorted) { kwargs_keys.push_back(kv.first); kwargs_values.push_back( diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index d6f117eb0d..fbd9e3188f 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -207,13 +207,12 @@ TVM_REGISTER_GLOBAL("testing.AcceptsBool").set_body_typed([](bool arg) -> bool { TVM_REGISTER_GLOBAL("testing.AcceptsInt").set_body_typed([](int arg) -> int { return arg; }); -TVM_REGISTER_GLOBAL("testing.AcceptsObjectRefArray") - .set_body_typed([](Array<Any> arg) -> Any { return arg[0]; }); +TVM_REGISTER_GLOBAL("testing.AcceptsObjectRefArray").set_body_typed([](Array<Any> arg) -> Any { + return arg[0]; +}); TVM_REGISTER_GLOBAL("testing.AcceptsMapReturnsValue") - .set_body_typed([](Map<Any, Any> map, Any key) -> Any { - return map[key]; - }); + .set_body_typed([](Map<Any, Any> map, Any key) -> Any { return map[key]; }); TVM_REGISTER_GLOBAL("testing.AcceptsMapReturnsMap") .set_body_typed([](Map<Any, Any> map) -> ObjectRef { return map; }); diff --git a/src/target/codegen.cc b/src/target/codegen.cc index f5f1ed546f..78a89872ce 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -352,7 +352,6 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod, bool system_lib, std::string blob = PackImportsToBytes(mod); - // Call codegen_blob to generate LLVM module std::string codegen_f_name = "codegen.codegen_blob"; // the codegen function. diff --git a/src/target/tag.cc b/src/target/tag.cc index d17a83e671..0f398d6e19 100644 --- a/src/target/tag.cc +++ b/src/target/tag.cc @@ -78,10 +78,10 @@ TVM_REGISTER_TARGET_TAG("raspberry-pi/4b-aarch64") {"mattr", Array<String>{"+neon"}}, {"num-cores", 4}, {"host", Map<String, ffi::Any>{{"kind", String("llvm")}, - {"mtriple", String("aarch64-linux-gnu")}, - {"mcpu", String("cortex-a72")}, - {"mattr", Array<String>{"+neon"}}, - {"num-cores", 4}}}}); + {"mtriple", String("aarch64-linux-gnu")}, + {"mcpu", String("cortex-a72")}, + {"mattr", Array<String>{"+neon"}}, + {"num-cores", 4}}}}); #if TVM_LLVM_VERSION >= 110 TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-xavier") @@ -92,9 +92,9 @@ TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-xavier") {"thread_warp_size", 32}, {"registers_per_block", 65536}, {"host", Map<String, ffi::Any>{{"kind", String("llvm")}, - {"mtriple", String("aarch64-linux-gnu")}, - {"mcpu", String("carmel")}, - {"num-cores", 8}}}}); + {"mtriple", String("aarch64-linux-gnu")}, + {"mcpu", String("carmel")}, + {"num-cores", 8}}}}); TVM_REGISTER_TARGET_TAG("nvidia/jetson-orin-nano") .set_config({{"kind", String("cuda")}, @@ -104,9 +104,9 @@ TVM_REGISTER_TARGET_TAG("nvidia/jetson-orin-nano") {"thread_warp_size", 32}, {"registers_per_block", 65536}, {"host", Map<String, ffi::Any>{{"kind", String("llvm")}, - {"mtriple", String("aarch64-linux-gnu")}, - {"mcpu", String("carmel")}, - {"num-cores", 6}}}}); + {"mtriple", String("aarch64-linux-gnu")}, + {"mcpu", String("carmel")}, + {"num-cores", 6}}}}); TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-orin-32gb") .set_config({{"kind", String("cuda")}, @@ -116,9 +116,9 @@ TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-orin-32gb") {"thread_warp_size", 32}, {"registers_per_block", 65536}, {"host", Map<String, ffi::Any>{{"kind", String("llvm")}, - {"mtriple", String("aarch64-linux-gnu")}, - {"mcpu", String("cortex-a78")}, - {"num-cores", 8}}}}); + {"mtriple", String("aarch64-linux-gnu")}, + {"mcpu", String("cortex-a78")}, + {"num-cores", 8}}}}); TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-orin-64gb") .set_config({{"kind", String("cuda")}, @@ -128,9 +128,9 @@ TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-orin-64gb") {"thread_warp_size", 32}, {"registers_per_block", 65536}, {"host", Map<String, ffi::Any>{{"kind", String("llvm")}, - {"mtriple", String("aarch64-linux-gnu")}, - {"mcpu", String("cortex-a78")}, - {"num-cores", 12}}}}); + {"mtriple", String("aarch64-linux-gnu")}, + {"mcpu", String("cortex-a78")}, + {"num-cores", 12}}}}); #endif // TVM_LLVM_VERSION >= 110 #endif // TVM_LLVM_HAS_AARCH64_TARGET @@ -139,10 +139,10 @@ TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-orin-64gb") {"kind", String("cuda")}, \ {"keys", Array<String>{"cuda", "gpu"}}, \ {"arch", String(Arch)}, \ - {"max_shared_memory_per_block", SharedMem}, \ - {"max_threads_per_block", 1024}, \ - {"thread_warp_size", 32}, \ - {"registers_per_block", RegPerBlock}, \ + {"max_shared_memory_per_block", SharedMem}, \ + {"max_threads_per_block", 1024}, \ + {"thread_warp_size", 32}, \ + {"registers_per_block", RegPerBlock}, \ }) // Naming convention for CUDA tags see https://developer.nvidia.com/cuda-gpus @@ -430,25 +430,25 @@ TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.24xlarge", 48, "cascadelake"); #undef TVM_REGISTER_TAG_AWS_C5 #if TVM_LLVM_VERSION >= 190 -#define TVM_REGISTER_METAL_GPU_TAG(Name, ThreadsPerBlock, SharedMem, WarpSize) \ - TVM_REGISTER_TARGET_TAG(Name).set_config( \ - {{"kind", String("metal")}, \ - {"max_threads_per_block", ThreadsPerBlock}, \ - {"max_shared_memory_per_block", SharedMem}, \ - {"thread_warp_size", WarpSize}, \ +#define TVM_REGISTER_METAL_GPU_TAG(Name, ThreadsPerBlock, SharedMem, WarpSize) \ + TVM_REGISTER_TARGET_TAG(Name).set_config( \ + {{"kind", String("metal")}, \ + {"max_threads_per_block", ThreadsPerBlock}, \ + {"max_shared_memory_per_block", SharedMem}, \ + {"thread_warp_size", WarpSize}, \ {"host", Map<String, ffi::Any>{{"kind", String("llvm")}, \ - {"mtriple", String("arm64-apple-macos")}, \ - {"mcpu", String("apple-m4")}}}}); + {"mtriple", String("arm64-apple-macos")}, \ + {"mcpu", String("apple-m4")}}}}); #else -#define TVM_REGISTER_METAL_GPU_TAG(Name, ThreadsPerBlock, SharedMem, WarpSize) \ - TVM_REGISTER_TARGET_TAG(Name).set_config( \ - {{"kind", String("metal")}, \ - {"max_threads_per_block", ThreadsPerBlock}, \ - {"max_shared_memory_per_block", SharedMem}, \ - {"thread_warp_size", WarpSize}, \ +#define TVM_REGISTER_METAL_GPU_TAG(Name, ThreadsPerBlock, SharedMem, WarpSize) \ + TVM_REGISTER_TARGET_TAG(Name).set_config( \ + {{"kind", String("metal")}, \ + {"max_threads_per_block", ThreadsPerBlock}, \ + {"max_shared_memory_per_block", SharedMem}, \ + {"thread_warp_size", WarpSize}, \ {"host", Map<String, ffi::Any>{{"kind", String("llvm")}, \ - {"mtriple", String("arm64-apple-macos")}, \ - {"mcpu", String("apple-latest")}}}}); + {"mtriple", String("arm64-apple-macos")}, \ + {"mcpu", String("apple-latest")}}}}); #endif #if TVM_LLVM_HAS_AARCH64_TARGET diff --git a/src/target/target.cc b/src/target/target.cc index ad35927561..be96abc8dd 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -970,7 +970,7 @@ ObjectPtr<Object> TargetInternal::FromConfig(Map<String, ffi::Any> config) { } // namespace tvm std::unordered_map<String, ffi::Any> TargetInternal::QueryDevice(int device_id, - const TargetNode* target) { + const TargetNode* target) { std::unordered_map<String, ffi::Any> output; Device device{static_cast<DLDeviceType>(target->GetTargetDeviceType()), device_id}; diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 83e998e154..64e8f46852 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -296,7 +296,7 @@ Array<Buffer> GenerateOutputBuffers(const te::ComputeOp& compute_op, CreateFuncI * \returns The block annotation dict. **/ Map<String, ffi::Any> GenerateBlockAnnotations(const te::ComputeOp& compute_op, - CreateFuncInfo* info) { + CreateFuncInfo* info) { Map<String, ffi::Any> annotations; auto mutate_attr = [&info](const ffi::Any& value) -> ffi::Any { if (auto tensor_value = value.as<te::Tensor>()) { diff --git a/src/te/operation/placeholder_op.cc b/src/te/operation/placeholder_op.cc index 6421a86656..53a001a5bb 100644 --- a/src/te/operation/placeholder_op.cc +++ b/src/te/operation/placeholder_op.cc @@ -21,9 +21,9 @@ * \brief Placeholder op. * \file placeholder_op.cc */ +#include <tvm/runtime/container/variant.h> #include <tvm/runtime/registry.h> #include <tvm/te/operation.h> -#include <tvm/runtime/container/variant.h> namespace tvm { namespace te { diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index cd4a3265f1..c57c6044f5 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -87,8 +87,7 @@ class ConcreteScheduleNode : public ScheduleNode { public: /******** Schedule: Sampling ********/ - ExprRV SampleCategorical(const Array<Integer>& candidates, - const Array<FloatImm>& probs, + ExprRV SampleCategorical(const Array<Integer>& candidates, const Array<FloatImm>& probs, Optional<Integer> decision = NullOpt) override; Array<ExprRV> SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, Optional<Array<Integer>> decision = NullOpt) override; diff --git a/src/tir/schedule/instruction_traits.h b/src/tir/schedule/instruction_traits.h index 934d0012e2..fe3158894e 100644 --- a/src/tir/schedule/instruction_traits.h +++ b/src/tir/schedule/instruction_traits.h @@ -349,9 +349,8 @@ String UnpackedInstTraits<TTraits>::AsPython(const Array<Any>& inputs, const Arr PackedFunc pf([](const TVMArgs& args, TVMRetValue* rv) -> void { constexpr size_t kNumArgs = details::NumArgs<method_type>; ICHECK_EQ(args.size(), kNumArgs); - ffi::details::unpack_call<return_type, kNumArgs>( - nullptr, TTraits::UnpackedAsPython, args.data(), - args.size(), rv); + ffi::details::unpack_call<return_type, kNumArgs>(nullptr, TTraits::UnpackedAsPython, + args.data(), args.size(), rv); }); ffi::Any rv; pf.CallPacked(ffi::PackedArgs(packed_args, kNumArgs), &rv); diff --git a/src/tir/schedule/ir_comparator.cc b/src/tir/schedule/ir_comparator.cc index 450d4926fe..71b646855d 100644 --- a/src/tir/schedule/ir_comparator.cc +++ b/src/tir/schedule/ir_comparator.cc @@ -361,7 +361,7 @@ bool TensorizeComparator::CompareAnnotation(const std::pair<String, ffi::Any>& l } // handle expr values if (lhs.second.as<PrimExpr>() && rhs.second.as<PrimExpr>()) { - return VisitExpr(Downcast<PrimExpr>(lhs.second), Downcast<PrimExpr>(rhs.second)); + return VisitExpr(Downcast<PrimExpr>(lhs.second), Downcast<PrimExpr>(rhs.second)); } // handle any other values via any equal if (!ffi::AnyEqual()(lhs.second, rhs.second)) { @@ -389,8 +389,7 @@ bool TensorizeComparator::CompareAnnotationMap(const Map<String, ffi::Any>& lhs, return false; } - auto sort_map = - [](const Map<String, ffi::Any>& map) -> std::vector<std::pair<String, ffi::Any>> { + auto sort_map = [](const Map<String, ffi::Any>& map) -> std::vector<std::pair<String, ffi::Any>> { std::vector<std::pair<String, ffi::Any>> ret(map.begin(), map.end()); sort(ret.begin(), ret.end(), [](const auto& a, const auto& b) { return a.first < b.first; }); return ret; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index d3f199cccd..de8fe7238e 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -55,8 +55,7 @@ std::vector<int32_t> SampleWithoutReplacement( * \return The random variable sampled from candidates */ TVM_DLL int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, - const Array<Integer>& candidates, - const Array<FloatImm>& probs, + const Array<Integer>& candidates, const Array<FloatImm>& probs, Optional<Integer>* decision); /*! * \brief Create a sampling function that does multinomial sampling. diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 95b4131a10..1d3cabee1d 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -460,9 +460,9 @@ struct SampleCategoricalTraits : public UnpackedInstTraits<SampleCategoricalTrai static constexpr size_t kNumAttrs = 2; static constexpr size_t kNumDecisions = 1; - static ExprRV UnpackedApplyToSchedule(Schedule sch, // + static ExprRV UnpackedApplyToSchedule(Schedule sch, // Array<Integer> candidates, // - Array<FloatImm> probs, // + Array<FloatImm> probs, // Optional<Integer> decision) { return sch->SampleCategorical(candidates, probs, decision); } diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 0c1bee7252..30eeccfd85 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -47,8 +47,7 @@ class TracedScheduleNode : public ConcreteScheduleNode { public: /******** Schedule: Sampling ********/ - ExprRV SampleCategorical(const Array<Integer>& candidates, - const Array<FloatImm>& probs, + ExprRV SampleCategorical(const Array<Integer>& candidates, const Array<FloatImm>& probs, Optional<Integer> decision = NullOpt) final; Array<ExprRV> SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, Optional<Array<Integer>> decision = NullOpt) final; diff --git a/src/tir/transforms/default_gpu_schedule.cc b/src/tir/transforms/default_gpu_schedule.cc index 057ed63fab..249003d2ad 100644 --- a/src/tir/transforms/default_gpu_schedule.cc +++ b/src/tir/transforms/default_gpu_schedule.cc @@ -86,8 +86,7 @@ IRModule MarkScheduled(const IRModule& mod) { for (const auto& [gv, base_func] : mod->functions) { if (const auto* prim_func_node = base_func.as<tir::PrimFuncNode>()) { tir::PrimFunc prim_func = GetRef<tir::PrimFunc>(prim_func_node); - tir::PrimFunc new_prim_func = - WithAttr(std::move(prim_func), tir::attr::kIsScheduled, true); + tir::PrimFunc new_prim_func = WithAttr(std::move(prim_func), tir::attr::kIsScheduled, true); result.Set(gv, new_prim_func); } else { result.Set(gv, base_func); diff --git a/src/tir/transforms/primfunc_utils.cc b/src/tir/transforms/primfunc_utils.cc index 4085ee2293..1cd9bb21a7 100644 --- a/src/tir/transforms/primfunc_utils.cc +++ b/src/tir/transforms/primfunc_utils.cc @@ -80,8 +80,8 @@ transform::Pass AnnotateEntryFunc() { bool is_external = base_func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined(); if (is_external) { if (auto ptr = base_func.as<PrimFuncNode>()) { - with_annotations->Add( - gvar, WithAttr(GetRef<PrimFunc>(ptr), tir::attr::kIsEntryFunc, true)); + with_annotations->Add(gvar, + WithAttr(GetRef<PrimFunc>(ptr), tir::attr::kIsEntryFunc, true)); } else { has_external_non_primfuncs = true; } diff --git a/src/topi/broadcast.cc b/src/topi/broadcast.cc index 0e2963567e..09409077c8 100644 --- a/src/topi/broadcast.cc +++ b/src/topi/broadcast.cc @@ -34,8 +34,8 @@ using namespace tvm::runtime; #define TOPI_REGISTER_BCAST_OP(OpName, Op) \ TVM_REGISTER_GLOBAL(OpName).set_body([](TVMArgs args, TVMRetValue* rv) { \ - bool lhs_is_tensor = args[0].as<tvm::te::Tensor>().has_value(); \ - bool rhs_is_tensor = args[1].as<tvm::te::Tensor>().has_value(); \ + bool lhs_is_tensor = args[0].as<tvm::te::Tensor>().has_value(); \ + bool rhs_is_tensor = args[1].as<tvm::te::Tensor>().has_value(); \ if (lhs_is_tensor && rhs_is_tensor) { \ *rv = Op(args[0].operator tvm::te::Tensor(), args[1].operator tvm::te::Tensor()); \ } else if (!lhs_is_tensor && rhs_is_tensor) { \ diff --git a/tests/cpp-runtime/hexagon/hexagon_device_api_tests.cc b/tests/cpp-runtime/hexagon/hexagon_device_api_tests.cc index d7d306ff65..c6ce0c72f5 100644 --- a/tests/cpp-runtime/hexagon/hexagon_device_api_tests.cc +++ b/tests/cpp-runtime/hexagon/hexagon_device_api_tests.cc @@ -46,9 +46,9 @@ class HexagonDeviceAPITest : public ::testing::Test { int64_t shape2d[2]{256, 256}; int64_t shape3d[3]{256, 256, 256}; Optional<String> default_scope; - Optional<String> invalid_scope{"invalid"}; - Optional<String> global_scope{"global"}; - Optional<String> global_vtcm_scope{"global.vtcm"}; + Optional<String> invalid_scope = String("invalid"); + Optional<String> global_scope = String("global"); + Optional<String> global_vtcm_scope = String("global.vtcm"); }; TEST_F(HexagonDeviceAPITest, global) { CHECK(hexapi != nullptr); } diff --git a/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc b/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc index 9697006bf1..e9c81fa911 100644 --- a/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc +++ b/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc @@ -55,8 +55,8 @@ class HexagonUserDMATest : public ::testing::Test { uint32_t length = 0x4000; // 16KB const bool ENABLE_BYPASS = true; const bool DISABLE_BYPASS = false; - Optional<String> global_scope{"global"}; - Optional<String> global_vtcm_scope{"global.vtcm"}; + Optional<String> global_scope = String("global"); + Optional<String> global_vtcm_scope = String("global.vtcm"); }; TEST_F(HexagonUserDMATest, wait) { diff --git a/tests/cpp/nested_msg_test.cc b/tests/cpp/nested_msg_test.cc index 4934154bdc..52bfa2cdd3 100644 --- a/tests/cpp/nested_msg_test.cc +++ b/tests/cpp/nested_msg_test.cc @@ -39,7 +39,6 @@ using namespace tvm; using namespace tvm::runtime; using namespace tvm::relax; - TEST(NestedMsg, Basic) { // start with no annotation relax::Var x("x", NullOpt), y("y", NullOpt);
