samskalicky commented on a change in pull request #15921: [WIP] dynamic custom
operator support
URL: https://github.com/apache/incubator-mxnet/pull/15921#discussion_r321955086
##########
File path: include/mxnet/lib_api.h
##########
@@ -18,33 +18,627 @@
*/
/*!
- * Copyright (c) 2015 by Contributors
+ * Copyright (c) 2019 by Contributors
* \file lib_api.h
* \brief APIs to interact with libraries
+ * This API specifies function prototypes to
+ * register custom ops for library authors
*/
+
#ifndef MXNET_LIB_API_H_
#define MXNET_LIB_API_H_
+#include <stdint.h>
+#include <vector>
+#include <map>
+#include <string>
+
+#define MX_LIBRARY_VERSION 1
+
/*!
- * \brief Following are the APIs implemented in the external library
+ * \brief External Tensor data types
+ */
+enum MXDType {  // element-type tag; stored in MXTensor::dtype
+ kFloat32 = 0,  // every enumerator is given an explicit code, none is left implicit
+ kFloat64 = 1,
+ kFloat16 = 2,
+ kUint8 = 3,
+ kInt32 = 4,
+ kInt8 = 5,
+ kInt64 = 6,
+};  // NOTE(review): codes presumably mirror MXNet's internal type flags -- confirm
+
+enum MXReturnValue {  // status code returned by the fcomp_t ... fstateful_t callbacks
+ MX_FAIL = 0,  // zero, so a failure status converts to false
+ MX_SUCCESS = 1,  // non-zero, so a success status converts to true
+};
+
+/*!
+ * \brief External Tensor data structure
+ */
+struct MXTensor {
+  /*!
+   * \brief build an empty tensor: null data pointer, empty shape, dtype kFloat32
+   * dtype is initialized explicitly so that reading it from a
+   * default-constructed tensor never touches an indeterminate value
+   */
+  MXTensor() : data(nullptr), dtype(kFloat32) {}
+
+  /*!
+   * \brief describe an existing buffer
+   * \param data pointer to the buffer; stored as-is (see "not owned" below)
+   * \param shape extent of each dimension; copied into the tensor
+   * \param dtype element type tag recorded for the buffer
+   */
+  MXTensor(void *data, const std::vector<int64_t> &shape, MXDType dtype)
+    : data{data}, shape{shape}, dtype{dtype} {}
+
+  /*!
+   * \brief helper function to cast data pointer
+   * \tparam data_type type the caller wants to view the buffer as;
+   *         it is not checked against dtype
+   * \return data reinterpreted as data_type*
+   */
+  template<typename data_type>
+  data_type* getData() {
+    return reinterpret_cast<data_type*>(data);
+  }
+
+  void *data;  // not owned
+  std::vector<int64_t> shape;
+  MXDType dtype;
+};
+
+/*!
+ * \brief resource malloc function to allocate memory inside fcompute function
+ */
+typedef void* (*xpu_malloc_t)(void*, int);  // OpResource::alloc calls it as (handle, size)
+
+/*!
+ * \brief Class to provide resource APIs to FCompute
+ */
+class OpResource {
+ public:
+  /*!
+   * \brief pair an allocation callback with the opaque handle passed back to it
+   * \param xm allocation callback
+   * \param _xm opaque pointer forwarded as the callback's first argument
+   */
+  OpResource(xpu_malloc_t xm, void* _xm)
+    : xpu_malloc(xm), _xpu_malloc(_xm) {}
+
+  /*!
+   * \brief allocate memory controlled by MXNet
+   * \param size amount requested, forwarded unchanged to the callback
+   * \return whatever pointer the callback returns
+   */
+  void* alloc(int size) {
+    void* const buffer = xpu_malloc(_xpu_malloc, size);
+    return buffer;
+  }
+
+ private:
+  xpu_malloc_t xpu_malloc;  // allocation callback
+  void* _xpu_malloc;  // handle handed to the callback on every call
+};
+
+/*!
+ * \brief StatefulOp wrapper class to pass to backend OpState
+ */
+class CustomStatefulOpWrapper {
+ public:
+  /*!
+   * \brief wrap an opaque instance pointer
+   * explicit, like the CustomOp constructor, so a bare void* is never
+   * converted into a wrapper implicitly
+   * \param inst pointer to keep; stored as-is
+   */
+  explicit CustomStatefulOpWrapper(void* inst) : instance(inst) {}
+  /*! \brief return the pointer given at construction, unchanged */
+  void* get_instance() const { return instance; }
+ private:
+  void* instance;  // this class neither dereferences nor frees it
+};
+
+/*!
+ * \brief An prototype interface class for library author creating stateful op
+ */
+class CustomStatefulOp {
+ public:
+  /*! \brief forward computation; every concrete stateful op must override it */
+  virtual void Forward() = 0;
+  /*!
+   * \brief virtual destructor so deleting through a base pointer is safe
+   * Defined inline instead of being declared pure: a derived class's
+   * destructor always calls the base destructor, so it needs a definition.
+   * The class remains abstract because Forward() is pure.
+   */
+  virtual ~CustomStatefulOp() {}
+};
+
+/*!
+ * Custom Operator function templates
+ */
+typedef MXReturnValue (*fcomp_t)(std::map<std::string, std::string>,
+ std::vector<MXTensor>, std::vector<MXTensor>,
+ OpResource res);  // stored by CustomOp::setForward and CustomOp::setGradient
+typedef MXReturnValue (*parseAttrs_t)(std::map<std::string, std::string>,
+ int*, int*);  // stored by CustomOp::setParseAttrs
+typedef MXReturnValue (*inferType_t)(std::map<std::string, std::string>,
+ std::vector<int>&, std::vector<int>&);  // stored by CustomOp::setInferType
+typedef MXReturnValue (*inferShape_t)(std::map<std::string, std::string>,
+ std::vector<std::vector<unsigned int>>&,
+ std::vector<std::vector<unsigned int>>&);  // stored by CustomOp::setInferShape
+typedef MXReturnValue (*mutateInputs_t)(std::map<std::string, std::string>,
+ std::vector<int>&);  // stored by CustomOp::setMutateInputs
+typedef MXReturnValue (*createOpState_t)(std::map<std::string, std::string>,
+ CustomStatefulOp**);  // stored by CustomOp::setCreateOpState
+typedef MXReturnValue (*fstateful_t)(CustomStatefulOp*, std::vector<MXTensor>,
+ std::vector<MXTensor>);  // stored by CustomOp::setForwardStateful
+
+/*!
+ * \brief Class to hold custom operator registration
+ */
+class CustomOp {
+ public:
+ /*!
+  * \brief begin a registration under op_name
+  * Every callback slot is set to nullptr here; slots are filled afterwards
+  * through the set* methods.
+  */
+ explicit CustomOp(const char* op_name)
+   : name(op_name),
+     fcompute(nullptr),
+     fgradient(nullptr),
+     parse_attrs(nullptr),
+     infer_type(nullptr),
+     infer_shape(nullptr),
+     mutate_inputs(nullptr),
+     create_op_state(nullptr),
+     fstateful(nullptr) {}
+ ~CustomOp() {}  // empty body: no cleanup is performed here
+ /*! \brief store fcomp in fcompute; returns *this so setter calls can be chained */
+ CustomOp& setForward(fcomp_t fcomp) {
+   this->fcompute = fcomp;
+   return *this;
+ }
+ /*! \brief store fcomp in fgradient; returns *this so setter calls can be chained */
+ CustomOp& setGradient(fcomp_t fcomp) {
+   this->fgradient = fcomp;
+   return *this;
+ }
+ /*! \brief store func in parse_attrs; returns *this so setter calls can be chained */
+ CustomOp& setParseAttrs(parseAttrs_t func) {
+   this->parse_attrs = func;
+   return *this;
+ }
+ /*! \brief store func in infer_type; returns *this so setter calls can be chained */
+ CustomOp& setInferType(inferType_t func) {
+   this->infer_type = func;
+   return *this;
+ }
+ /*! \brief store func in infer_shape; returns *this so setter calls can be chained */
+ CustomOp& setInferShape(inferShape_t func) {
+   this->infer_shape = func;
+   return *this;
+ }
+ /*! \brief store func in mutate_inputs; returns *this so setter calls can be chained */
+ CustomOp& setMutateInputs(mutateInputs_t func) {
+   this->mutate_inputs = func;
+   return *this;
+ }
+ /*! \brief store func in create_op_state; returns *this so setter calls can be chained */
+ CustomOp& setCreateOpState(createOpState_t func) {
+   this->create_op_state = func;
+   return *this;
+ }
+ /*! \brief store func in fstateful; returns *this so setter calls can be chained */
+ CustomOp& setForwardStateful(fstateful_t func) {
+   this->fstateful = func;
+   return *this;
+ }
Review comment:
We need to add some error checking in lib_api.h to prevent users from
registering the wrong combo of functions. Currently we have 2 scenarios:
Basic Op Registration requires:
- Parse Attrs
- Infer Type
- Infer Shape
- Forward
__optional__
- Gradient --> Lets rename to Backward (to match with Forward)
- Mutate inputs
Stateful Op Registration requires:
- Parse Attrs
- Infer Type
- Infer Shape
- Create Op State
- Forward Stateful
__optional__
- Gradient --> Lets rename to Backward (to match with Forward)
- Mutate inputs
We need to enforce that the Basic Op Forward is not registered with Create
Op State or Forward Stateful and vice-versa. This will give users an error when
compiling their operator libraries before loading in MXNet
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services