d-smirnov commented on a change in pull request #8509:
URL: https://github.com/apache/tvm/pull/8509#discussion_r786289131
##########
File path: include/tvm/tir/stmt.h
##########
@@ -585,6 +585,118 @@ class Allocate : public Stmt {
TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
};
+/*!
+ * \brief Describes one parameter that should be linked into the generated module.
+ *
+ * When parameters are to be linked in with generated code (i.e. on target_host-compatible
+ * backends), Relay attaches instances of this object to a global TIR function. Code-generators
+ * use the information contained in this node to include the parameter data in the generated
+ * module.
+ */
+class LinkedParamNode : public Object {
+ public:
+  /*! \brief Unique numeric identifier used by runtimes to look up this parameter. */
+  int64_t id;
+
+  /*! \brief Parameter data which should get linked into the final module. */
+  ::tvm::runtime::NDArray param;
+
+  /*! \brief Visit this node's fields (`id` and `param`) with the given attribute visitor. */
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("id", &id);
+    v->Visit("param", &param);
+  }
+
+  static constexpr const char* _type_key = "tir.LinkedParam";
+  TVM_DECLARE_FINAL_OBJECT_INFO(LinkedParamNode, Object);
+};
+
+/*!
+ * \brief Managed reference to LinkedParamNode.
+ */
+class LinkedParam : public ObjectRef {
+ public:
+  /*!
+   * \brief Construct a LinkedParam from a numeric identifier and its parameter data.
+   * \param id Unique numeric identifier for the parameter.
+   * \param param The parameter data to be linked into the generated module.
+   */
+  TVM_DLL LinkedParam(int64_t id, ::tvm::runtime::NDArray param);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(LinkedParam, ObjectRef, LinkedParamNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
+};
+
+/*!
+ * \brief Allocate a buffer that can be used in body.
+ */
+class AllocateConstNode : public StmtNode {
+ public:
+ /*! \brief The buffer variable. */
+ Var buffer_var;
+ /*! \brief The optional data associated to the constant.
+ */
+ Optional<runtime::NDArray> data;
+ /*! \brief If the PrimFunc containing the Stmt is added to IRModule,
+ this is an optional index to indicate the index within
+ "Constants" attribute, that is a Array<NDArray> of IRModule.
+ */
+ Optional<Integer> irmod_storage_idx;
+ /*! \brief The type of the buffer. */
+ DataType dtype;
+ /*! \brief The extents of the buffer. */
+ Array<PrimExpr> extents;
+ /*! \brief The body to be executed. */
+ Stmt body;
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("buffer_var", &buffer_var);
+ v->Visit("dtype", &dtype);
+ v->Visit("extents", &extents);
+ v->Visit("body", &body);
+ v->Visit("span", &span);
+ }
+
+ bool SEqualReduce(const AllocateConstNode* other, SEqualReducer equal) const
{
+ return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype,
other->dtype) &&
+ equal(extents, other->extents) && equal(data, other->data) &&
equal(body, other->body);
+ }
+
+ void SHashReduce(SHashReducer hash_reduce) const {
+ hash_reduce.DefHash(buffer_var);
+ hash_reduce(dtype);
+ hash_reduce(extents);
+ hash_reduce(body);
+ hash_reduce(data);
+ }
+
+ /*!
+ * \brief If the buffer size is constant, return the size.
+ * Otherwise return 0.
+ * \return The result.
+ */
+ int32_t constant_allocation_size() const { return
constant_allocation_size(extents); }
Review comment:
It looks like this was written this way to stay consistent with the similarly named
method of AllocateNode. Could you clarify: should both nodes be refactored, then?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]