marcoabreu closed pull request #10023: R graph.viz fix
URL: https://github.com/apache/incubator-mxnet/pull/10023
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this pull request comes from a fork, GitHub will not show its
diff after the merge, so the diff is supplied below:

diff --git a/R-package/R/viz.graph.R b/R-package/R/viz.graph.R
index 6d13de0af1d..abc83236bb3 100644
--- a/R-package/R/viz.graph.R
+++ b/R-package/R/viz.graph.R
@@ -8,7 +8,6 @@
 #' @importFrom stringr str_trim
 #' @importFrom jsonlite fromJSON
 #' @importFrom DiagrammeR create_graph
-#' @importFrom DiagrammeR set_global_graph_attrs 
 #' @importFrom DiagrammeR add_global_graph_attrs
 #' @importFrom DiagrammeR create_node_df
 #' @importFrom DiagrammeR create_edge_df
@@ -63,91 +62,91 @@ graph.viz <- function(symbol, shape=NULL, direction="TD", 
type="graph", graph.wi
     )
   }
   
-  model_list<- fromJSON(symbol$as.json())
-  model_nodes<- model_list$nodes
-  model_nodes$id<- 1:nrow(model_nodes)-1
-  model_nodes$level<- model_nodes$ID
+  model_list <- fromJSON(symbol$as.json())
+  model_nodes <- model_list$nodes
+  model_nodes$id <- seq_len(nrow(model_nodes))-1
+  model_nodes$level <- model_nodes$ID
   
   # extract IDs from string list
-  tuple_str <- function(str) sapply(str_extract_all(str, "\\d+"), function(x) 
paste0(x, collapse="X"))
+  tuple_str <- function(str) vapply(str_extract_all(str, "\\d+"),
+                                    function(x) paste0(x, collapse="X"),
+                                    character(1))
   
   ### substitute op for heads
-  op_id<- sort(unique(model_list$heads[1,]+1))
-  op_null<- which(model_nodes$op=="null")
-  op_substitute<- intersect(op_id, op_null)
-  model_nodes$op[op_substitute]<- model_nodes$name[op_substitute]
-  
-  model_nodes$color<- apply(model_nodes["op"], 1, get.color)
-  model_nodes$shape<- apply(model_nodes["op"], 1, get.shape)
-  
-  label_paste <- paste0(
-    model_nodes$op,
-    "\n",
-    model_nodes$name,
-    "\n",
-    model_nodes$attr$num_hidden %>% str_replace_na() %>% 
str_replace_all(pattern = "NA", ""),
-    model_nodes$attr$act_type %>% str_replace_na() %>% str_replace_all(pattern 
= "NA", ""),
-    model_nodes$attr$pool_type %>% str_replace_na() %>% 
str_replace_all(pattern = "NA", ""),
-    model_nodes$attr$kernel %>% tuple_str %>% str_replace_na() %>% 
str_replace_all(pattern = "NA", ""),
-    " / ",
-    model_nodes$attr$stride %>% tuple_str %>% str_replace_na() %>% 
str_replace_all(pattern = "NA", ""),
-    ", ",
-    model_nodes$attr$num_filter %>% str_replace_na() %>% 
str_replace_all(pattern = "NA", "")
-  ) %>% 
+  op_id <- sort(unique(model_list$heads[1,]+1))
+  op_null <- which(model_nodes$op=="null")
+  op_substitute <- intersect(op_id, op_null)
+  model_nodes$op[op_substitute] <- model_nodes$name[op_substitute]
+  
+  model_nodes$color <- apply(model_nodes["op"], 1, get.color)
+  model_nodes$shape <- apply(model_nodes["op"], 1, get.shape)
+  
+  label_paste <- paste0(model_nodes$op,
+                        "\n",
+                        model_nodes$name,
+                        "\n",
+                        model_nodes$attr$num_hidden %>% str_replace_na() %>% 
str_replace_all(pattern = "NA", ""),
+                        model_nodes$attr$act_type %>% str_replace_na() %>% 
str_replace_all(pattern = "NA", ""),
+                        model_nodes$attr$pool_type %>% str_replace_na() %>% 
str_replace_all(pattern = "NA", ""),
+                        model_nodes$attr$kernel %>% tuple_str %>% 
str_replace_na() %>% str_replace_all(pattern = "NA", ""),
+                        " / ",
+                        model_nodes$attr$stride %>% tuple_str %>% 
str_replace_na() %>% str_replace_all(pattern = "NA", ""),
+                        ", ",
+                        model_nodes$attr$num_filter %>% str_replace_na() %>% 
str_replace_all(pattern = "NA", "")) %>% 
     str_replace_all(pattern = "[^[:alnum:]]+$", "")  %>% 
     str_trim
   
-  model_nodes$label<- label_paste
+  model_nodes$label <- label_paste
   
   id.to.keep <- model_nodes$id[!model_nodes$op=="null"]
   nodes_df <- model_nodes[model_nodes$id %in% id.to.keep, c("id", "label", 
"shape", "color")]
   
   ### remapping for DiagrammeR convention
-  nodes_df$id<- nodes_df$id
-  nodes_df$id_graph<- 1:nrow(nodes_df)
-  id_dic<- nodes_df$id_graph
-  names(id_dic)<- as.character(nodes_df$id)
-  
-  edges_id<- model_nodes$id[!sapply(model_nodes$inputs, length)==0 & 
!model_nodes$op=="null"]
-  edges_id<- id_dic[as.character(edges_id)]
-  edges<- model_nodes$inputs[!sapply(model_nodes$inputs, length)==0 & 
!model_nodes$op=="null"]
-  edges<- sapply(edges, function(x)intersect(as.numeric(x[, 1]), id.to.keep), 
simplify = F)
-  names(edges)<- edges_id
-  
-  edges_df<- data.frame(
-    from=unlist(edges),
-    to=rep(names(edges), time=sapply(edges, length)),
-    arrows = "to",
-    color="black",
-    from_name_output=paste0(model_nodes$name[unlist(edges)+1], "_output"), 
-    stringsAsFactors=F)
-  edges_df$from<- id_dic[as.character(edges_df$from)]
-  
-  nodes_df_new<- create_node_df(n = nrow(nodes_df), label=nodes_df$label, 
shape=nodes_df$shape, type="base", penwidth=2, color=nodes_df$color, 
style="filled", 
-                                fillcolor=adjustcolor(nodes_df$color, alpha.f 
= 1), fontcolor = "black")
-  edge_df_new<- create_edge_df(from = edges_df$from, to=edges_df$to, 
color="black", fontcolor = "black")
-  
-  if (!is.null(shape)){
+  nodes_df$id <- nodes_df$id
+  nodes_df$id_graph <- seq_len(nrow(nodes_df))
+  id_dic <- nodes_df$id_graph
+  names(id_dic) <- as.character(nodes_df$id)
+  
+  edges_id <- model_nodes$id[lengths(model_nodes$inputs)!=0 & 
model_nodes$op!="null"]
+  edges_id <- id_dic[as.character(edges_id)]
+  edges <- model_nodes$inputs[lengths(model_nodes$inputs)!=0 & 
model_nodes$op!="null"]
+  edges <- sapply(edges, function(x)intersect(as.numeric(x[, 1]), id.to.keep), 
simplify = FALSE)
+  names(edges) <- edges_id
+  
+  edges_df <- data.frame(from=unlist(edges),
+                         to=rep(names(edges), time=lengths(edges)),
+                         arrows = "to",
+                         color="black",
+                         
from_name_output=paste0(model_nodes$name[unlist(edges)+1], "_output"), 
+                         stringsAsFactors=FALSE)
+  edges_df$from <- id_dic[as.character(edges_df$from)]
+  
+  nodes_df_new <- create_node_df(n = nrow(nodes_df), label=nodes_df$label, 
shape=nodes_df$shape, type="base", penwidth=2, color=nodes_df$color, 
style="filled", 
+                                 fillcolor=adjustcolor(nodes_df$color, alpha.f 
= 1), fontcolor = "black")
+  edge_df_new <- create_edge_df(from = edges_df$from, to=edges_df$to, 
color="black", fontcolor = "black")
+  
+  if (!is.null(shape)) {
     if (is.list(shape)) {
-      edges_labels_raw<- symbol$get.internals()$infer.shape(shape)$out.shapes
-    } else edges_labels_raw<- 
symbol$get.internals()$infer.shape(list(data=shape))$out.shapes
-    if (!is.null(edges_labels_raw)){
+      edges_labels_raw <- symbol$get.internals()$infer.shape(shape)$out.shapes
+    } else edges_labels_raw <- 
symbol$get.internals()$infer.shape(list(data=shape))$out.shapes
+    if (!is.null(edges_labels_raw)) {
       edge_label_str <- function(x) paste0(x, collapse="X")
-      edges_labels_raw<- sapply(edges_labels_raw, edge_label_str)
-      names(edges_labels_raw)[names(edges_labels_raw)=="data"]<- "data_output"
-      edge_df_new$label<- edges_labels_raw[edges_df$from_name_output]
-      edge_df_new$rel<- edge_df_new$label
+      edges_labels_raw <- vapply(edges_labels_raw, edge_label_str, 
character(1))
+      names(edges_labels_raw)[names(edges_labels_raw)=="data"] <- "data_output"
+      edge_df_new$label <- edges_labels_raw[edges_df$from_name_output]
+      edge_df_new$rel <- edge_df_new$label
     }
   }
   
-  graph<- create_graph(nodes_df = nodes_df_new, edges_df = edge_df_new, 
directed = T) %>% 
-    set_global_graph_attrs("layout", value = "dot", attr_type = "graph") %>% 
+  graph <- create_graph(nodes_df = nodes_df_new, edges_df = edge_df_new, 
directed = TRUE, attr_theme = NULL) %>%
+    add_global_graph_attrs("layout", value = "dot", attr_type = "graph") %>%
     add_global_graph_attrs("rankdir", value = direction, attr_type = "graph")
   
   if (type=="vis"){
-    graph_render<- render_graph(graph = graph, output = "visNetwork", width = 
graph.width.px, height = graph.height.px) %>% visHierarchicalLayout(direction = 
direction, sortMethod = "directed")
+    graph_render <- render_graph(graph = graph, output = "visNetwork", width = 
graph.width.px, height = graph.height.px) %>% 
+      visHierarchicalLayout(direction = direction, sortMethod = "directed")
   } else {
-    graph_render<- render_graph(graph = graph, output = "graph", width = 
graph.width.px, height = graph.height.px)
+    graph_render <- render_graph(graph = graph, output = "graph", width = 
graph.width.px, height = graph.height.px)
   }
   
   return(graph_render)


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to