This is an automated email from the ASF dual-hosted git repository.

marcoabreu pushed a commit to branch v1.0.0
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.0.0 by this push:
     new 3df9bf8  R graph.viz fix (#10023)
3df9bf8 is described below

commit 3df9bf802021d5aa67c609c6736acee94aaf3a48
Author: jeremiedb <jeremi...@users.noreply.github.com>
AuthorDate: Wed Mar 7 08:28:23 2018 -0500

    R graph.viz fix (#10023)
    
    * R graph.viz fix
    
    * sub
---
 R-package/R/viz.graph.R | 129 ++++++++++++++++++++++++------------------------
 1 file changed, 64 insertions(+), 65 deletions(-)

diff --git a/R-package/R/viz.graph.R b/R-package/R/viz.graph.R
index 6d13de0..abc8323 100644
--- a/R-package/R/viz.graph.R
+++ b/R-package/R/viz.graph.R
@@ -8,7 +8,6 @@
 #' @importFrom stringr str_trim
 #' @importFrom jsonlite fromJSON
 #' @importFrom DiagrammeR create_graph
-#' @importFrom DiagrammeR set_global_graph_attrs 
 #' @importFrom DiagrammeR add_global_graph_attrs
 #' @importFrom DiagrammeR create_node_df
 #' @importFrom DiagrammeR create_edge_df
@@ -63,91 +62,91 @@ graph.viz <- function(symbol, shape=NULL, direction="TD", 
type="graph", graph.wi
     )
   }
   
-  model_list<- fromJSON(symbol$as.json())
-  model_nodes<- model_list$nodes
-  model_nodes$id<- 1:nrow(model_nodes)-1
-  model_nodes$level<- model_nodes$ID
+  model_list <- fromJSON(symbol$as.json())
+  model_nodes <- model_list$nodes
+  model_nodes$id <- seq_len(nrow(model_nodes))-1
+  model_nodes$level <- model_nodes$ID
   
   # extract IDs from string list
-  tuple_str <- function(str) sapply(str_extract_all(str, "\\d+"), function(x) 
paste0(x, collapse="X"))
+  tuple_str <- function(str) vapply(str_extract_all(str, "\\d+"),
+                                    function(x) paste0(x, collapse="X"),
+                                    character(1))
   
   ### substitute op for heads
-  op_id<- sort(unique(model_list$heads[1,]+1))
-  op_null<- which(model_nodes$op=="null")
-  op_substitute<- intersect(op_id, op_null)
-  model_nodes$op[op_substitute]<- model_nodes$name[op_substitute]
-  
-  model_nodes$color<- apply(model_nodes["op"], 1, get.color)
-  model_nodes$shape<- apply(model_nodes["op"], 1, get.shape)
-  
-  label_paste <- paste0(
-    model_nodes$op,
-    "\n",
-    model_nodes$name,
-    "\n",
-    model_nodes$attr$num_hidden %>% str_replace_na() %>% 
str_replace_all(pattern = "NA", ""),
-    model_nodes$attr$act_type %>% str_replace_na() %>% str_replace_all(pattern 
= "NA", ""),
-    model_nodes$attr$pool_type %>% str_replace_na() %>% 
str_replace_all(pattern = "NA", ""),
-    model_nodes$attr$kernel %>% tuple_str %>% str_replace_na() %>% 
str_replace_all(pattern = "NA", ""),
-    " / ",
-    model_nodes$attr$stride %>% tuple_str %>% str_replace_na() %>% 
str_replace_all(pattern = "NA", ""),
-    ", ",
-    model_nodes$attr$num_filter %>% str_replace_na() %>% 
str_replace_all(pattern = "NA", "")
-  ) %>% 
+  op_id <- sort(unique(model_list$heads[1,]+1))
+  op_null <- which(model_nodes$op=="null")
+  op_substitute <- intersect(op_id, op_null)
+  model_nodes$op[op_substitute] <- model_nodes$name[op_substitute]
+  
+  model_nodes$color <- apply(model_nodes["op"], 1, get.color)
+  model_nodes$shape <- apply(model_nodes["op"], 1, get.shape)
+  
+  label_paste <- paste0(model_nodes$op,
+                        "\n",
+                        model_nodes$name,
+                        "\n",
+                        model_nodes$attr$num_hidden %>% str_replace_na() %>% 
str_replace_all(pattern = "NA", ""),
+                        model_nodes$attr$act_type %>% str_replace_na() %>% 
str_replace_all(pattern = "NA", ""),
+                        model_nodes$attr$pool_type %>% str_replace_na() %>% 
str_replace_all(pattern = "NA", ""),
+                        model_nodes$attr$kernel %>% tuple_str %>% 
str_replace_na() %>% str_replace_all(pattern = "NA", ""),
+                        " / ",
+                        model_nodes$attr$stride %>% tuple_str %>% 
str_replace_na() %>% str_replace_all(pattern = "NA", ""),
+                        ", ",
+                        model_nodes$attr$num_filter %>% str_replace_na() %>% 
str_replace_all(pattern = "NA", "")) %>% 
     str_replace_all(pattern = "[^[:alnum:]]+$", "")  %>% 
     str_trim
   
-  model_nodes$label<- label_paste
+  model_nodes$label <- label_paste
   
   id.to.keep <- model_nodes$id[!model_nodes$op=="null"]
   nodes_df <- model_nodes[model_nodes$id %in% id.to.keep, c("id", "label", 
"shape", "color")]
   
   ### remapping for DiagrammeR convention
-  nodes_df$id<- nodes_df$id
-  nodes_df$id_graph<- 1:nrow(nodes_df)
-  id_dic<- nodes_df$id_graph
-  names(id_dic)<- as.character(nodes_df$id)
-  
-  edges_id<- model_nodes$id[!sapply(model_nodes$inputs, length)==0 & 
!model_nodes$op=="null"]
-  edges_id<- id_dic[as.character(edges_id)]
-  edges<- model_nodes$inputs[!sapply(model_nodes$inputs, length)==0 & 
!model_nodes$op=="null"]
-  edges<- sapply(edges, function(x)intersect(as.numeric(x[, 1]), id.to.keep), 
simplify = F)
-  names(edges)<- edges_id
-  
-  edges_df<- data.frame(
-    from=unlist(edges),
-    to=rep(names(edges), time=sapply(edges, length)),
-    arrows = "to",
-    color="black",
-    from_name_output=paste0(model_nodes$name[unlist(edges)+1], "_output"), 
-    stringsAsFactors=F)
-  edges_df$from<- id_dic[as.character(edges_df$from)]
-  
-  nodes_df_new<- create_node_df(n = nrow(nodes_df), label=nodes_df$label, 
shape=nodes_df$shape, type="base", penwidth=2, color=nodes_df$color, 
style="filled", 
-                                fillcolor=adjustcolor(nodes_df$color, alpha.f 
= 1), fontcolor = "black")
-  edge_df_new<- create_edge_df(from = edges_df$from, to=edges_df$to, 
color="black", fontcolor = "black")
-  
-  if (!is.null(shape)){
+  nodes_df$id <- nodes_df$id
+  nodes_df$id_graph <- seq_len(nrow(nodes_df))
+  id_dic <- nodes_df$id_graph
+  names(id_dic) <- as.character(nodes_df$id)
+  
+  edges_id <- model_nodes$id[lengths(model_nodes$inputs)!=0 & 
model_nodes$op!="null"]
+  edges_id <- id_dic[as.character(edges_id)]
+  edges <- model_nodes$inputs[lengths(model_nodes$inputs)!=0 & 
model_nodes$op!="null"]
+  edges <- sapply(edges, function(x)intersect(as.numeric(x[, 1]), id.to.keep), 
simplify = FALSE)
+  names(edges) <- edges_id
+  
+  edges_df <- data.frame(from=unlist(edges),
+                         to=rep(names(edges), time=lengths(edges)),
+                         arrows = "to",
+                         color="black",
+                         
from_name_output=paste0(model_nodes$name[unlist(edges)+1], "_output"), 
+                         stringsAsFactors=FALSE)
+  edges_df$from <- id_dic[as.character(edges_df$from)]
+  
+  nodes_df_new <- create_node_df(n = nrow(nodes_df), label=nodes_df$label, 
shape=nodes_df$shape, type="base", penwidth=2, color=nodes_df$color, 
style="filled", 
+                                 fillcolor=adjustcolor(nodes_df$color, alpha.f 
= 1), fontcolor = "black")
+  edge_df_new <- create_edge_df(from = edges_df$from, to=edges_df$to, 
color="black", fontcolor = "black")
+  
+  if (!is.null(shape)) {
     if (is.list(shape)) {
-      edges_labels_raw<- symbol$get.internals()$infer.shape(shape)$out.shapes
-    } else edges_labels_raw<- 
symbol$get.internals()$infer.shape(list(data=shape))$out.shapes
-    if (!is.null(edges_labels_raw)){
+      edges_labels_raw <- symbol$get.internals()$infer.shape(shape)$out.shapes
+    } else edges_labels_raw <- 
symbol$get.internals()$infer.shape(list(data=shape))$out.shapes
+    if (!is.null(edges_labels_raw)) {
       edge_label_str <- function(x) paste0(x, collapse="X")
-      edges_labels_raw<- sapply(edges_labels_raw, edge_label_str)
-      names(edges_labels_raw)[names(edges_labels_raw)=="data"]<- "data_output"
-      edge_df_new$label<- edges_labels_raw[edges_df$from_name_output]
-      edge_df_new$rel<- edge_df_new$label
+      edges_labels_raw <- vapply(edges_labels_raw, edge_label_str, 
character(1))
+      names(edges_labels_raw)[names(edges_labels_raw)=="data"] <- "data_output"
+      edge_df_new$label <- edges_labels_raw[edges_df$from_name_output]
+      edge_df_new$rel <- edge_df_new$label
     }
   }
   
-  graph<- create_graph(nodes_df = nodes_df_new, edges_df = edge_df_new, 
directed = T) %>% 
-    set_global_graph_attrs("layout", value = "dot", attr_type = "graph") %>% 
+  graph <- create_graph(nodes_df = nodes_df_new, edges_df = edge_df_new, 
directed = TRUE, attr_theme = NULL) %>%
+    add_global_graph_attrs("layout", value = "dot", attr_type = "graph") %>%
     add_global_graph_attrs("rankdir", value = direction, attr_type = "graph")
   
   if (type=="vis"){
-    graph_render<- render_graph(graph = graph, output = "visNetwork", width = 
graph.width.px, height = graph.height.px) %>% visHierarchicalLayout(direction = 
direction, sortMethod = "directed")
+    graph_render <- render_graph(graph = graph, output = "visNetwork", width = 
graph.width.px, height = graph.height.px) %>% 
+      visHierarchicalLayout(direction = direction, sortMethod = "directed")
   } else {
-    graph_render<- render_graph(graph = graph, output = "graph", width = 
graph.width.px, height = graph.height.px)
+    graph_render <- render_graph(graph = graph, output = "graph", width = 
graph.width.px, height = graph.height.px)
   }
   
   return(graph_render)

-- 
To stop receiving notification emails like this one, please contact
marcoab...@apache.org.

Reply via email to