samskalicky commented on a change in pull request #18405: URL: https://github.com/apache/incubator-mxnet/pull/18405#discussion_r450349091
########## File path: python/mxnet/gluon/block.py ########## @@ -1040,41 +1040,62 @@ def _build_cache(self, *args): warnings.warn("Parameter %s is not used by any computation. " "Is this intended?"%unused, stacklevel=4) - data_indices = [] - param_indices = [] - self._cached_op_args = [] - for i, name in enumerate(input_names): - if name in data_names: - data_indices.append(i) - self._cached_op_args.append((True, data_names[name])) - else: - param_indices.append(i) - self._cached_op_args.append((False, params[name])) - flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \ - self._flags - args, _ = _flatten(args, "input") try: - for is_arg, i in self._cached_op_args: - if not is_arg: - i.data() + for name in input_names: + if name in params: + params[name].data() except DeferredInitializationError: self._deferred_infer_shape(*args) - for is_arg, i in self._cached_op_args: - if not is_arg: - i._finish_deferred_init() + for name in input_names: + if name in params: + params[name]._finish_deferred_init() + arg_dict, aux_dict = dict(), dict() if self._backend: ctx = args[0].context # get list of params in the order of out.list_arguments - arg_dict = {name:args[data_names[name]] if name in data_names.keys() else params[name].data() - for name in out.list_arguments()} - aux_dict = {name:args[data_names[name]] if name in data_names.keys() else params[name].data() - for name in out.list_auxiliary_states()} + arg_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data() + for name in out.list_arguments()}) + aux_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data() + for name in out.list_auxiliary_states()}) # Partition the graph. 
out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, **self._backend_opts) + #update cached graph with partitioned graph self._cached_graph = data, out + + input_names = out.list_inputs() + data_indices = [] + param_indices = [] + self._cached_op_args = [] + for i, name in enumerate(input_names): + pair = None + if name in data_names: + data_indices.append(i) + pair = (True, data_names[name]) + else: + param_indices.append(i) + if name in params: + param = params[name] + else: + assert self._backend, "Parameter " + name + " is missing from block params" + if name in arg_dict or name: + param_data = arg_dict[name] + elif name in aux_dict: + param_data = aux_dict[name] + else: + raise RuntimeError('Expected inputs missing from arg and aux after partioning. ' Review comment: Is this the case when the backend added a param to the graph but not to the arg/aux dict? If so, can we change the error message to say something like "param <x> was added to the graph but the tensor was not added to args/aux"? ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org