marcoabreu closed pull request #9999: R fix
URL: https://github.com/apache/incubator-mxnet/pull/9999
 
 
   

This is a PR merged from a forked repository.
Because GitHub hides the original diff on merge, it is reproduced below
for provenance:

Because this is a foreign pull request (from a fork), the diff is supplied
below; GitHub would not otherwise display it:

diff --git a/R-package/R/viz.graph.R b/R-package/R/viz.graph.R
index 49f978a2cd9..abc83236bb3 100644
--- a/R-package/R/viz.graph.R
+++ b/R-package/R/viz.graph.R
@@ -8,7 +8,6 @@
 #' @importFrom stringr str_trim
 #' @importFrom jsonlite fromJSON
 #' @importFrom DiagrammeR create_graph
-#' @importFrom DiagrammeR set_global_graph_attrs 
 #' @importFrom DiagrammeR add_global_graph_attrs
 #' @importFrom DiagrammeR create_node_df
 #' @importFrom DiagrammeR create_edge_df
@@ -63,93 +62,91 @@ graph.viz <- function(symbol, shape=NULL, direction="TD", 
type="graph", graph.wi
     )
   }
   
-  model_list<- fromJSON(symbol$as.json())
-  model_nodes<- model_list$nodes
-  model_nodes$id<- seq_len(nrow(model_nodes))-1
-  model_nodes$level<- model_nodes$ID
+  model_list <- fromJSON(symbol$as.json())
+  model_nodes <- model_list$nodes
+  model_nodes$id <- seq_len(nrow(model_nodes))-1
+  model_nodes$level <- model_nodes$ID
   
   # extract IDs from string list
   tuple_str <- function(str) vapply(str_extract_all(str, "\\d+"),
                                     function(x) paste0(x, collapse="X"),
                                     character(1))
-
+  
   ### substitute op for heads
-  op_id<- sort(unique(model_list$heads[1,]+1))
-  op_null<- which(model_nodes$op=="null")
-  op_substitute<- intersect(op_id, op_null)
-  model_nodes$op[op_substitute]<- model_nodes$name[op_substitute]
-  
-  model_nodes$color<- apply(model_nodes["op"], 1, get.color)
-  model_nodes$shape<- apply(model_nodes["op"], 1, get.shape)
-  
-  label_paste <- paste0(
-    model_nodes$op,
-    "\n",
-    model_nodes$name,
-    "\n",
-    model_nodes$attr$num_hidden %>% str_replace_na() %>% 
str_replace_all(pattern = "NA", ""),
-    model_nodes$attr$act_type %>% str_replace_na() %>% str_replace_all(pattern 
= "NA", ""),
-    model_nodes$attr$pool_type %>% str_replace_na() %>% 
str_replace_all(pattern = "NA", ""),
-    model_nodes$attr$kernel %>% tuple_str %>% str_replace_na() %>% 
str_replace_all(pattern = "NA", ""),
-    " / ",
-    model_nodes$attr$stride %>% tuple_str %>% str_replace_na() %>% 
str_replace_all(pattern = "NA", ""),
-    ", ",
-    model_nodes$attr$num_filter %>% str_replace_na() %>% 
str_replace_all(pattern = "NA", "")
-  ) %>% 
+  op_id <- sort(unique(model_list$heads[1,]+1))
+  op_null <- which(model_nodes$op=="null")
+  op_substitute <- intersect(op_id, op_null)
+  model_nodes$op[op_substitute] <- model_nodes$name[op_substitute]
+  
+  model_nodes$color <- apply(model_nodes["op"], 1, get.color)
+  model_nodes$shape <- apply(model_nodes["op"], 1, get.shape)
+  
+  label_paste <- paste0(model_nodes$op,
+                        "\n",
+                        model_nodes$name,
+                        "\n",
+                        model_nodes$attr$num_hidden %>% str_replace_na() %>% 
str_replace_all(pattern = "NA", ""),
+                        model_nodes$attr$act_type %>% str_replace_na() %>% 
str_replace_all(pattern = "NA", ""),
+                        model_nodes$attr$pool_type %>% str_replace_na() %>% 
str_replace_all(pattern = "NA", ""),
+                        model_nodes$attr$kernel %>% tuple_str %>% 
str_replace_na() %>% str_replace_all(pattern = "NA", ""),
+                        " / ",
+                        model_nodes$attr$stride %>% tuple_str %>% 
str_replace_na() %>% str_replace_all(pattern = "NA", ""),
+                        ", ",
+                        model_nodes$attr$num_filter %>% str_replace_na() %>% 
str_replace_all(pattern = "NA", "")) %>% 
     str_replace_all(pattern = "[^[:alnum:]]+$", "")  %>% 
     str_trim
   
-  model_nodes$label<- label_paste
+  model_nodes$label <- label_paste
   
   id.to.keep <- model_nodes$id[!model_nodes$op=="null"]
   nodes_df <- model_nodes[model_nodes$id %in% id.to.keep, c("id", "label", 
"shape", "color")]
   
   ### remapping for DiagrammeR convention
-  nodes_df$id<- nodes_df$id
-  nodes_df$id_graph<- seq_len(nrow(nodes_df))
-  id_dic<- nodes_df$id_graph
-  names(id_dic)<- as.character(nodes_df$id)
-  
-  edges_id<- model_nodes$id[lengths(model_nodes$inputs)!=0 & 
model_nodes$op!="null"]
-  edges_id<- id_dic[as.character(edges_id)]
-  edges<- model_nodes$inputs[lengths(model_nodes$inputs)!=0 & 
model_nodes$op!="null"]
-  edges<- sapply(edges, function(x)intersect(as.numeric(x[, 1]), id.to.keep), 
simplify = FALSE)
-  names(edges)<- edges_id
-  
-  edges_df<- data.frame(
-    from=unlist(edges),
-    to=rep(names(edges), time=lengths(edges)),
-    arrows = "to",
-    color="black",
-    from_name_output=paste0(model_nodes$name[unlist(edges)+1], "_output"), 
-    stringsAsFactors=FALSE)
-  edges_df$from<- id_dic[as.character(edges_df$from)]
-  
-  nodes_df_new<- create_node_df(n = nrow(nodes_df), label=nodes_df$label, 
shape=nodes_df$shape, type="base", penwidth=2, color=nodes_df$color, 
style="filled", 
-                                fillcolor=adjustcolor(nodes_df$color, alpha.f 
= 1), fontcolor = "black")
-  edge_df_new<- create_edge_df(from = edges_df$from, to=edges_df$to, 
color="black", fontcolor = "black")
-  
-  if (!is.null(shape)){
+  nodes_df$id <- nodes_df$id
+  nodes_df$id_graph <- seq_len(nrow(nodes_df))
+  id_dic <- nodes_df$id_graph
+  names(id_dic) <- as.character(nodes_df$id)
+  
+  edges_id <- model_nodes$id[lengths(model_nodes$inputs)!=0 & 
model_nodes$op!="null"]
+  edges_id <- id_dic[as.character(edges_id)]
+  edges <- model_nodes$inputs[lengths(model_nodes$inputs)!=0 & 
model_nodes$op!="null"]
+  edges <- sapply(edges, function(x)intersect(as.numeric(x[, 1]), id.to.keep), 
simplify = FALSE)
+  names(edges) <- edges_id
+  
+  edges_df <- data.frame(from=unlist(edges),
+                         to=rep(names(edges), time=lengths(edges)),
+                         arrows = "to",
+                         color="black",
+                         
from_name_output=paste0(model_nodes$name[unlist(edges)+1], "_output"), 
+                         stringsAsFactors=FALSE)
+  edges_df$from <- id_dic[as.character(edges_df$from)]
+  
+  nodes_df_new <- create_node_df(n = nrow(nodes_df), label=nodes_df$label, 
shape=nodes_df$shape, type="base", penwidth=2, color=nodes_df$color, 
style="filled", 
+                                 fillcolor=adjustcolor(nodes_df$color, alpha.f 
= 1), fontcolor = "black")
+  edge_df_new <- create_edge_df(from = edges_df$from, to=edges_df$to, 
color="black", fontcolor = "black")
+  
+  if (!is.null(shape)) {
     if (is.list(shape)) {
-      edges_labels_raw<- symbol$get.internals()$infer.shape(shape)$out.shapes
-    } else edges_labels_raw<- 
symbol$get.internals()$infer.shape(list(data=shape))$out.shapes
-    if (!is.null(edges_labels_raw)){
+      edges_labels_raw <- symbol$get.internals()$infer.shape(shape)$out.shapes
+    } else edges_labels_raw <- 
symbol$get.internals()$infer.shape(list(data=shape))$out.shapes
+    if (!is.null(edges_labels_raw)) {
       edge_label_str <- function(x) paste0(x, collapse="X")
-      edges_labels_raw<- vapply(edges_labels_raw, edge_label_str, character(1))
-      names(edges_labels_raw)[names(edges_labels_raw)=="data"]<- "data_output"
-      edge_df_new$label<- edges_labels_raw[edges_df$from_name_output]
-      edge_df_new$rel<- edge_df_new$label
+      edges_labels_raw <- vapply(edges_labels_raw, edge_label_str, 
character(1))
+      names(edges_labels_raw)[names(edges_labels_raw)=="data"] <- "data_output"
+      edge_df_new$label <- edges_labels_raw[edges_df$from_name_output]
+      edge_df_new$rel <- edge_df_new$label
     }
   }
   
-  graph<- create_graph(nodes_df = nodes_df_new, edges_df = edge_df_new, 
directed = TRUE) %>% 
-    set_global_graph_attrs("layout", value = "dot", attr_type = "graph") %>% 
+  graph <- create_graph(nodes_df = nodes_df_new, edges_df = edge_df_new, 
directed = TRUE, attr_theme = NULL) %>%
+    add_global_graph_attrs("layout", value = "dot", attr_type = "graph") %>%
     add_global_graph_attrs("rankdir", value = direction, attr_type = "graph")
   
   if (type=="vis"){
-    graph_render<- render_graph(graph = graph, output = "visNetwork", width = 
graph.width.px, height = graph.height.px) %>% visHierarchicalLayout(direction = 
direction, sortMethod = "directed")
+    graph_render <- render_graph(graph = graph, output = "visNetwork", width = 
graph.width.px, height = graph.height.px) %>% 
+      visHierarchicalLayout(direction = direction, sortMethod = "directed")
   } else {
-    graph_render<- render_graph(graph = graph, output = "graph", width = 
graph.width.px, height = graph.height.px)
+    graph_render <- render_graph(graph = graph, output = "graph", width = 
graph.width.px, height = graph.height.px)
   }
   
   return(graph_render)


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to