marcoabreu closed pull request #9999: R fix URL: https://github.com/apache/incubator-mxnet/pull/9999
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/R-package/R/viz.graph.R b/R-package/R/viz.graph.R index 49f978a2cd9..abc83236bb3 100644 --- a/R-package/R/viz.graph.R +++ b/R-package/R/viz.graph.R @@ -8,7 +8,6 @@ #' @importFrom stringr str_trim #' @importFrom jsonlite fromJSON #' @importFrom DiagrammeR create_graph -#' @importFrom DiagrammeR set_global_graph_attrs #' @importFrom DiagrammeR add_global_graph_attrs #' @importFrom DiagrammeR create_node_df #' @importFrom DiagrammeR create_edge_df @@ -63,93 +62,91 @@ graph.viz <- function(symbol, shape=NULL, direction="TD", type="graph", graph.wi ) } - model_list<- fromJSON(symbol$as.json()) - model_nodes<- model_list$nodes - model_nodes$id<- seq_len(nrow(model_nodes))-1 - model_nodes$level<- model_nodes$ID + model_list <- fromJSON(symbol$as.json()) + model_nodes <- model_list$nodes + model_nodes$id <- seq_len(nrow(model_nodes))-1 + model_nodes$level <- model_nodes$ID # extract IDs from string list tuple_str <- function(str) vapply(str_extract_all(str, "\\d+"), function(x) paste0(x, collapse="X"), character(1)) - + ### substitute op for heads - op_id<- sort(unique(model_list$heads[1,]+1)) - op_null<- which(model_nodes$op=="null") - op_substitute<- intersect(op_id, op_null) - model_nodes$op[op_substitute]<- model_nodes$name[op_substitute] - - model_nodes$color<- apply(model_nodes["op"], 1, get.color) - model_nodes$shape<- apply(model_nodes["op"], 1, get.shape) - - label_paste <- paste0( - model_nodes$op, - "\n", - model_nodes$name, - "\n", - model_nodes$attr$num_hidden %>% str_replace_na() %>% str_replace_all(pattern = "NA", ""), - model_nodes$attr$act_type %>% str_replace_na() %>% str_replace_all(pattern = "NA", ""), - model_nodes$attr$pool_type %>% str_replace_na() %>% str_replace_all(pattern = "NA", ""), - model_nodes$attr$kernel %>% tuple_str %>% str_replace_na() %>% str_replace_all(pattern = "NA", ""), - " / ", - model_nodes$attr$stride %>% tuple_str %>% str_replace_na() %>% str_replace_all(pattern = "NA", ""), - ", ", - model_nodes$attr$num_filter %>% str_replace_na() %>% str_replace_all(pattern = "NA", "") - ) %>% + op_id <- sort(unique(model_list$heads[1,]+1)) + op_null <- which(model_nodes$op=="null") + op_substitute <- intersect(op_id, op_null) + model_nodes$op[op_substitute] <- model_nodes$name[op_substitute] + + model_nodes$color <- apply(model_nodes["op"], 1, get.color) + model_nodes$shape <- apply(model_nodes["op"], 1, get.shape) + + label_paste <- paste0(model_nodes$op, + "\n", + model_nodes$name, + "\n", + model_nodes$attr$num_hidden %>% str_replace_na() %>% str_replace_all(pattern = "NA", ""), + model_nodes$attr$act_type %>% str_replace_na() %>% str_replace_all(pattern = "NA", ""), + model_nodes$attr$pool_type %>% str_replace_na() %>% str_replace_all(pattern = "NA", ""), + model_nodes$attr$kernel %>% tuple_str %>% str_replace_na() %>% str_replace_all(pattern = "NA", ""), + " / ", + model_nodes$attr$stride %>% tuple_str %>% str_replace_na() %>% str_replace_all(pattern = "NA", ""), + ", ", + model_nodes$attr$num_filter %>% str_replace_na() %>% str_replace_all(pattern = "NA", "")) %>% str_replace_all(pattern = "[^[:alnum:]]+$", "") %>% str_trim - model_nodes$label<- label_paste + model_nodes$label <- label_paste id.to.keep <- model_nodes$id[!model_nodes$op=="null"] nodes_df <- model_nodes[model_nodes$id %in% id.to.keep, c("id", "label", "shape", "color")] ### remapping for DiagrammeR convention - nodes_df$id<- nodes_df$id - nodes_df$id_graph<- seq_len(nrow(nodes_df)) - id_dic<- nodes_df$id_graph - names(id_dic)<- as.character(nodes_df$id) - - edges_id<- model_nodes$id[lengths(model_nodes$inputs)!=0 & model_nodes$op!="null"] - edges_id<- id_dic[as.character(edges_id)] - edges<- model_nodes$inputs[lengths(model_nodes$inputs)!=0 & model_nodes$op!="null"] - edges<- sapply(edges, function(x)intersect(as.numeric(x[, 1]), id.to.keep), simplify = FALSE) - names(edges)<- edges_id - - edges_df<- data.frame( - from=unlist(edges), - to=rep(names(edges), time=lengths(edges)), - arrows = "to", - color="black", - from_name_output=paste0(model_nodes$name[unlist(edges)+1], "_output"), - stringsAsFactors=FALSE) - edges_df$from<- id_dic[as.character(edges_df$from)] - - nodes_df_new<- create_node_df(n = nrow(nodes_df), label=nodes_df$label, shape=nodes_df$shape, type="base", penwidth=2, color=nodes_df$color, style="filled", - fillcolor=adjustcolor(nodes_df$color, alpha.f = 1), fontcolor = "black") - edge_df_new<- create_edge_df(from = edges_df$from, to=edges_df$to, color="black", fontcolor = "black") - - if (!is.null(shape)){ + nodes_df$id <- nodes_df$id + nodes_df$id_graph <- seq_len(nrow(nodes_df)) + id_dic <- nodes_df$id_graph + names(id_dic) <- as.character(nodes_df$id) + + edges_id <- model_nodes$id[lengths(model_nodes$inputs)!=0 & model_nodes$op!="null"] + edges_id <- id_dic[as.character(edges_id)] + edges <- model_nodes$inputs[lengths(model_nodes$inputs)!=0 & model_nodes$op!="null"] + edges <- sapply(edges, function(x)intersect(as.numeric(x[, 1]), id.to.keep), simplify = FALSE) + names(edges) <- edges_id + + edges_df <- data.frame(from=unlist(edges), + to=rep(names(edges), time=lengths(edges)), + arrows = "to", + color="black", + from_name_output=paste0(model_nodes$name[unlist(edges)+1], "_output"), + stringsAsFactors=FALSE) + edges_df$from <- id_dic[as.character(edges_df$from)] + + nodes_df_new <- create_node_df(n = nrow(nodes_df), label=nodes_df$label, shape=nodes_df$shape, type="base", penwidth=2, color=nodes_df$color, style="filled", + fillcolor=adjustcolor(nodes_df$color, alpha.f = 1), fontcolor = "black") + edge_df_new <- create_edge_df(from = edges_df$from, to=edges_df$to, color="black", fontcolor = "black") + + if (!is.null(shape)) { if (is.list(shape)) { - edges_labels_raw<- symbol$get.internals()$infer.shape(shape)$out.shapes - } else edges_labels_raw<- symbol$get.internals()$infer.shape(list(data=shape))$out.shapes - if (!is.null(edges_labels_raw)){ + edges_labels_raw <- symbol$get.internals()$infer.shape(shape)$out.shapes + } else edges_labels_raw <- symbol$get.internals()$infer.shape(list(data=shape))$out.shapes + if (!is.null(edges_labels_raw)) { edge_label_str <- function(x) paste0(x, collapse="X") - edges_labels_raw<- vapply(edges_labels_raw, edge_label_str, character(1)) - names(edges_labels_raw)[names(edges_labels_raw)=="data"]<- "data_output" - edge_df_new$label<- edges_labels_raw[edges_df$from_name_output] - edge_df_new$rel<- edge_df_new$label + edges_labels_raw <- vapply(edges_labels_raw, edge_label_str, character(1)) + names(edges_labels_raw)[names(edges_labels_raw)=="data"] <- "data_output" + edge_df_new$label <- edges_labels_raw[edges_df$from_name_output] + edge_df_new$rel <- edge_df_new$label } } - graph<- create_graph(nodes_df = nodes_df_new, edges_df = edge_df_new, directed = TRUE) %>% - set_global_graph_attrs("layout", value = "dot", attr_type = "graph") %>% + graph <- create_graph(nodes_df = nodes_df_new, edges_df = edge_df_new, directed = TRUE, attr_theme = NULL) %>% + add_global_graph_attrs("layout", value = "dot", attr_type = "graph") %>% add_global_graph_attrs("rankdir", value = direction, attr_type = "graph") if (type=="vis"){ - graph_render<- render_graph(graph = graph, output = "visNetwork", width = graph.width.px, height = graph.height.px) %>% visHierarchicalLayout(direction = direction, sortMethod = "directed") + graph_render <- render_graph(graph = graph, output = "visNetwork", width = graph.width.px, height = graph.height.px) %>% + visHierarchicalLayout(direction = direction, sortMethod = "directed") } else { - graph_render<- render_graph(graph = graph, output = "graph", width = graph.width.px, height = graph.height.px) + graph_render <- render_graph(graph = graph, output = "graph", width = graph.width.px, height = graph.height.px) } return(graph_render) ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services