samskalicky commented on a change in pull request #17270: [WIP] Dynamic custom operator GPU support
URL: https://github.com/apache/incubator-mxnet/pull/17270#discussion_r369133681
 
 

 ##########
 File path: include/mxnet/lib_api.h
 ##########
 @@ -740,68 +787,81 @@ class Registry {
 typedef int (*opRegSize_t)(void);
 
 #define MXLIB_OPREGGET_STR "_opRegGet"
-typedef int (*opRegGet_t)(int, const char**, fcomp_t*, fcomp_t*,
-                          parseAttrs_t*, inferType_t*,
-                          inferShape_t*, mutateInputs_t*,
-                          createOpState_t*, int*);
+typedef int (*opRegGet_t)(int idx, const char** name, int *isSGop,
+                          const char*** forward_ctx, fcomp_t** forward_fp, int* forward_count,
+                          const char*** backward_ctx, fcomp_t** backward_fp, int* backward_count,
+                          const char*** create_op_ctx, createOpState_t** create_op_fp, int* create_op_count,
+                          parseAttrs_t* parse, inferType_t* type,
+                          inferShape_t* shape, mutateInputs_t* mutate);
 
 #define MXLIB_OPCALLFREE_STR "_opCallFree"
-typedef int (*opCallFree_t)(void*);
+typedef int (*opCallFree_t)(void* ptr);
 
 #define MXLIB_OPCALLPARSEATTRS_STR "_opCallParseAttrs"
-typedef int (*opCallParseAttrs_t)(parseAttrs_t, const char* const*, const char* const*, int,
-                                  int*, int*);
+typedef int (*opCallParseAttrs_t)(parseAttrs_t parseAttrs, const char* const* keys,
+                                  const char* const* vals, int num,
+                                  int* num_in, int* num_out);
 
 #define MXLIB_OPCALLINFERSHAPE_STR "_opCallInferShape"
-typedef int (*opCallInferShape_t)(inferShape_t, const char* const*, const char* const*, int,
-                                  unsigned int**, int*, int,
-                                  unsigned int***, int**, int);
+typedef int (*opCallInferShape_t)(inferShape_t inferShape, const char* const* keys,
+                                  const char* const* vals, int num,
+                                  unsigned int** inshapes, int* indims, int num_in,
+                                  unsigned int*** outshapes, int** outdims, int num_out);
 
 #define MXLIB_OPCALLINFERTYPE_STR "_opCallInferType"
-typedef int (*opCallInferType_t)(inferType_t, const char* const*, const char* const*, int,
-                                  int*, int, int*, int);
+typedef int (*opCallInferType_t)(inferType_t inferType, const char* const* keys,
+                                 const char* const* vals, int num,
+                                 int* intypes, int num_in, int* outtypes, int num_out);
 
 #define MXLIB_OPCALLFCOMP_STR "_opCallFCompute"
-typedef int (*opCallFComp_t)(fcomp_t, const char* const*, const char* const*, int,
-                             const int64_t**, int*, void**, int*, size_t*, int,
-                             const int64_t**, int*, void**, int*, size_t*, int,
-                             xpu_malloc_t, void*);
+typedef int (*opCallFComp_t)(fcomp_t fcomp, const char* const* keys, const char* const* vals, int num,
+                             const int64_t** inshapes, int* indims, void** indata, int* intypes,
+                             size_t* inIDs, const char** indev_type, int* indev_id, int num_in,
+                             const int64_t** outshapes, int* outdims, void** outdata, int* outtypes,
+                             size_t* outIDs, const char** outdev_type, int* outdev_id, int num_out,
+                             xpu_malloc_t cpu_malloc, void* cpu_alloc, void* stream);
 
 #define MXLIB_OPCALLMUTATEINPUTS_STR "_opCallMutateInputs"
-typedef int (*opCallMutateInputs_t)(mutateInputs_t, const char* const*, const char* const*, int,
-                                    int**, int*);
+typedef int (*opCallMutateInputs_t)(mutateInputs_t mutate, const char* const* keys,
+                                    const char* const* vals, int num,
+                                    int** mutate_indices, int* indices_size);
 
 #define MXLIB_OPCALLCREATEOPSTATE_STR "_opCallCreateOpState"
-typedef int (*opCallCreateOpState_t)(createOpState_t, const char* const*, const char* const*, int,
-                                     void**);
+typedef int (*opCallCreateOpState_t)(createOpState_t create_op, const char* const* keys,
+                                     const char* const* vals, int num,
+                                     void** state_op);
 
 #define MXLIB_OPCALLFSTATEFULCOMP_STR "_opCallFStatefulCompute"
-typedef int (*opCallFStatefulComp_t)(int, void*, const int64_t**, int*, void**, int*, size_t*,
-                                     int, const int64_t**, int*, void**, int*, size_t*,
-                                     int, xpu_malloc_t, void*);
+typedef int (*opCallFStatefulComp_t)(int is_forward, void* state_op,
+                                     const int64_t** inshapes, int* indims, void** indata, int* intypes,
+                                     size_t* inIDs, const char** indev_type, int* indev_id, int num_in,
+                                     const int64_t** outshapes, int* outdims, void** outdata, int* outtypes,
+                                     size_t* outIDs, const char** outdev_type, int* outdev_id, int num_out,
+                                     xpu_malloc_t cpu_malloc, void* cpu_alloc, void* stream);
 
 #define MXLIB_PARTREGSIZE_STR "_partRegSize"
 typedef int (*partRegSize_t)(void);
 
 #define MXLIB_PARTREGGETCOUNT_STR "_partRegGetCount"
-typedef int (*partRegGetCount_t)(int, const char**);
+typedef int (*partRegGetCount_t)(int idx, const char** name);
 
 #define MXLIB_PARTREGGET_STR "_partRegGet"
-typedef void (*partRegGet_t)(int, int, const char**, supportedOps_t*,
-                            acceptSubgraph_t*, const char**);
+typedef void (*partRegGet_t)(int part_idx, int stg_idx, const char** strategy, supportedOps_t* supportedOps,
+                             acceptSubgraph_t* acceptSubgraph, const char** op_name);
 
 #define MXLIB_PARTCALLSUPPORTEDOPS_STR "_partCallSupportedOps"
-typedef int (*partCallSupportedOps_t)(supportedOps_t, const char*, int, int *,
-                                      const char* const*, const char* const*, int);
+typedef int (*partCallSupportedOps_t)(supportedOps_t supportedOps, const char *json,
+                                      int num_ids, int *ids, const char* const* opt_keys,
+                                      const char* const* opt_vals, int num_opts);
+
 #define MXLIB_PARTCALLACCEPTSUBGRAPH_STR "_partCallAcceptSubgraph"
-typedef int (*partCallAcceptSubgraph_t)(acceptSubgraph_t acceptSubgraph,
-                                        const char *json, int subgraph_id,
-                                        int *accept, const char* const*,
-                                        const char* const*, int,
-                                        char***, char***, int*);
+typedef int (*partCallAcceptSubgraph_t)(acceptSubgraph_t acceptSubgraph, const char *json,
+                                        int subgraph_id, int *accept, const char* const* opt_keys,
+                                        const char* const* opt_vals, int num_opts,
+                                        char*** attr_keys, char*** attr_vals, int *num_attrs);
 
 #define MXLIB_INITIALIZE_STR "initialize"
-typedef int (*initialize_t)(int);
+typedef int (*initialize_t)(int version);
 
 Review comment:
  Thanks for adding names to all the parameters; this makes the code much more readable and maintainable.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to