Make Gremlin.Net graph traversal API type-safe

All steps are now type-safe and use the original argument names from
Gremlin-Java. However, some Java types such as Comparator are not yet
supported; they are simply mapped to object until we find a better
solution. A drawback of this workaround is that some Gremlin-Java
overloads are not available in Gremlin.Net, because they would map to
the same C# method signature.
This required changing how Bindings work, because Bindings.Of() can no
longer return a Binding object: it now returns the bound value itself so
that it can be passed to the typed step parameters. The Bindings
implementation is now essentially the same as in Gremlin-Java.

This also revealed a bug in the tests, which called the WithoutStrategies
source step with strategy instances instead of just their types. However,
WithoutStrategies still cannot work yet, because there is no GraphSON
serializer for Type.


Project: http://git-wip-us.apache.org/repos/asf/tinkerpop/repo
Commit: http://git-wip-us.apache.org/repos/asf/tinkerpop/commit/05851764
Tree: http://git-wip-us.apache.org/repos/asf/tinkerpop/tree/05851764
Diff: http://git-wip-us.apache.org/repos/asf/tinkerpop/diff/05851764

Branch: refs/heads/TINKERPOP-1752
Commit: 05851764f1e20abb1a82b7d662b4681d602a5774
Parents: c59393f
Author: Florian Hockmann <[email protected]>
Authored: Thu Aug 17 22:57:07 2017 +0200
Committer: florianhockmann <[email protected]>
Committed: Tue Sep 12 17:20:21 2017 +0200

----------------------------------------------------------------------
 gremlin-dotnet/glv/AnonymousTraversal.template  |   11 +-
 gremlin-dotnet/glv/GraphTraversal.template      |   12 +-
 .../glv/GraphTraversalSource.template           |   24 +-
 gremlin-dotnet/pom.xml                          |  160 ++-
 .../Gremlin.Net/Process/Traversal/Bindings.cs   |   35 +-
 .../Gremlin.Net/Process/Traversal/Bytecode.cs   |   80 +-
 .../Process/Traversal/GraphTraversal.cs         | 1123 +++++++++++++++---
 .../Process/Traversal/GraphTraversalSource.cs   |   92 +-
 .../src/Gremlin.Net/Process/Traversal/__.cs     |  910 +++++++++++---
 .../BytecodeGenerationTests.cs                  |   14 +-
 .../BytecodeGeneration/StrategiesTests.cs       |   14 +-
 .../GraphTraversalTests.cs                      |   22 +-
 .../Process/Traversal/BytecodeTests.cs          |  142 ++-
 13 files changed, 2183 insertions(+), 456 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/tinkerpop/blob/05851764/gremlin-dotnet/glv/AnonymousTraversal.template
----------------------------------------------------------------------
diff --git a/gremlin-dotnet/glv/AnonymousTraversal.template 
b/gremlin-dotnet/glv/AnonymousTraversal.template
index 9bc7257..6b1de9c 100644
--- a/gremlin-dotnet/glv/AnonymousTraversal.template
+++ b/gremlin-dotnet/glv/AnonymousTraversal.template
@@ -43,9 +43,16 @@ namespace Gremlin.Net.Process.Traversal
         /// <summary>
         ///     Spawns a <see cref="GraphTraversal{SType, EType}" /> and adds 
the <%= method.methodName %> step to that traversal.
         /// </summary>
-        public static GraphTraversal<object, <%= method.t2 %>> <%= 
toCSharpMethodName.call(method.methodName) %><%= method.tParam %>(params 
object[] args)
+        public static GraphTraversal<object, <%= method.t2 %>> <%= 
toCSharpMethodName.call(method.methodName) %><%= method.tParam %>(<%= 
method.parameters %>)
         {
-            return new GraphTraversal<object, object>().<%= 
toCSharpMethodName.call(method.methodName) %><%= method.tParam %>(args);
+        <%  if (method.parameters.contains("params ")) {
+      %>    return <%= method.paramNames.last() %>.Length == 0
+                ? new GraphTraversal<object, <%= method.graphTraversalT2 
%>>().<%= toCSharpMethodName.call(method.methodName) %><%= 
method.callGenericTypeArg %>(<%= method.paramNames.init().join(", ") %>)
+                : new GraphTraversal<object, <%= method.graphTraversalT2 
%>>().<%= toCSharpMethodName.call(method.methodName) %><%= 
method.callGenericTypeArg %>(<%= method.paramNames.join(", ") %>);<%
+            }
+            else {
+      %>    return new GraphTraversal<object, <%= method.graphTraversalT2 
%>>().<%= toCSharpMethodName.call(method.methodName) %><%= 
method.callGenericTypeArg %>(<%= method.paramNames.join(", ") %>);<%
+            } %>            
         }
 <% } %>
     }

http://git-wip-us.apache.org/repos/asf/tinkerpop/blob/05851764/gremlin-dotnet/glv/GraphTraversal.template
----------------------------------------------------------------------
diff --git a/gremlin-dotnet/glv/GraphTraversal.template 
b/gremlin-dotnet/glv/GraphTraversal.template
index 5c3e03e..8d88fcb 100644
--- a/gremlin-dotnet/glv/GraphTraversal.template
+++ b/gremlin-dotnet/glv/GraphTraversal.template
@@ -65,9 +65,17 @@ namespace Gremlin.Net.Process.Traversal
         /// <summary>
         ///     Adds the <%= method.methodName %> step to this <see 
cref="GraphTraversal{SType, EType}" />.
         /// </summary>
-        public GraphTraversal< <%= method.t1 %> , <%= method.t2 %> > <%= 
toCSharpMethodName.call(method.methodName) %><%= method.tParam %> (params 
object[] args)
+        public GraphTraversal< <%= method.t1 %> , <%= method.t2 %> > <%= 
toCSharpMethodName.call(method.methodName) %><%= method.tParam %> (<%= 
method.parameters %>)
         {
-            Bytecode.AddStep("<%= method.methodName %>", args);
+        <%  if (method.parameters.contains("params ")) {
+      %>    var args = new List<object> {<%= method.paramNames.init().join(", 
") %>};
+            args.AddRange(<%= method.paramNames.last() %>);
+            Bytecode.AddStep("<%= method.methodName %>", args.ToArray());<%
+            }
+            else {
+      %>    Bytecode.AddStep("<%= method.methodName %>"<% if 
(method.parameters) out << ', '+ method.paramNames.join(", ") %>);<%
+            }
+        %>
             return Wrap< <%= method.t1 %> , <%= method.t2 %> >(this);
         }
 <% } %>

http://git-wip-us.apache.org/repos/asf/tinkerpop/blob/05851764/gremlin-dotnet/glv/GraphTraversalSource.template
----------------------------------------------------------------------
diff --git a/gremlin-dotnet/glv/GraphTraversalSource.template 
b/gremlin-dotnet/glv/GraphTraversalSource.template
index 0d98433..b67dfd7 100644
--- a/gremlin-dotnet/glv/GraphTraversalSource.template
+++ b/gremlin-dotnet/glv/GraphTraversalSource.template
@@ -72,11 +72,19 @@ namespace Gremlin.Net.Process.Traversal
         }
 
 <% sourceStepMethods.each{ method -> %>
-        public GraphTraversalSource <%= toCSharpMethodName.call(method) 
%>(params object[] args)
+        public GraphTraversalSource <%= 
toCSharpMethodName.call(method.methodName) %>(<%= method.parameters %>)
         {
             var source = new GraphTraversalSource(new 
List<ITraversalStrategy>(TraversalStrategies),
                                                   new Bytecode(Bytecode));
-            source.Bytecode.AddSource("<%= method %>", args);
+            <%  if (method.parameters.contains("params ")) {
+          %>var args = new List<object> {<%= method.paramNames.init().join(", 
") %>};
+            args.AddRange(<%= method.paramNames.last() %>);
+            source.Bytecode.AddSource("<%= method.methodName %>", 
args.ToArray());<%
+            }
+            else {
+          %>source.Bytecode.AddSource("<%= method.methodName %>"<% if 
(method.parameters) out << ', '+ method.paramNames.join(", ") %>);<%
+            }
+        %>
             return source;
         }
 <% } %>
@@ -119,10 +127,18 @@ namespace Gremlin.Net.Process.Traversal
         ///     Spawns a <see cref="GraphTraversal{SType, EType}" /> off this 
graph traversal source and adds the <%= method.methodName %> step to that
         ///     traversal.
         /// </summary>
-        public GraphTraversal< <%= method.typeArguments.join(",") %> > <%= 
toCSharpMethodName.call(method.methodName) %>(params object[] args)
+        public GraphTraversal< <%= method.typeArguments.join(",") %> > <%= 
toCSharpMethodName.call(method.methodName) %>(<%= method.parameters %>)
         {
             var traversal = new GraphTraversal< <%= 
method.typeArguments.join(",") %> >(TraversalStrategies, new 
Bytecode(Bytecode));
-            traversal.Bytecode.AddStep("<%= method.methodName %>", args);
+            <%  if (method.parameters.contains("params ")) {
+          %>var args = new List<object> {<%= method.paramNames.init().join(", 
") %>};
+            args.AddRange(<%= method.paramNames.last() %>);
+            traversal.Bytecode.AddStep("<%= method.methodName %>", 
args.ToArray());<%
+            }
+            else {
+      %>    traversal.Bytecode.AddStep("<%= method.methodName %>"<% if 
(method.parameters) out << ', '+ method.paramNames.join(", ") %>);<%
+            }
+        %>
             return traversal;
         }
 <% } %>

http://git-wip-us.apache.org/repos/asf/tinkerpop/blob/05851764/gremlin-dotnet/pom.xml
----------------------------------------------------------------------
diff --git a/gremlin-dotnet/pom.xml b/gremlin-dotnet/pom.xml
index d2ab17c..5d796bc 100644
--- a/gremlin-dotnet/pom.xml
+++ b/gremlin-dotnet/pom.xml
@@ -85,15 +85,35 @@ import java.lang.reflect.Modifier
 def toCSharpTypeMap = ["Long": "long",
                        "Integer": "int",
                        "String": "string",
+                       "boolean": "bool",
                        "Object": "object",
+                       "String[]": "string[]",
+                       "Object[]": "object[]",
+                       "Class": "Type",
+                       "Class[]": "Type[]",
                        "java.util.Map<java.lang.String, E2>": 
"IDictionary<string, E2>",
                        "java.util.Map<java.lang.String, B>": 
"IDictionary<string, E2>",
                        "java.util.List<E>": "IList<E>",
+                       "java.util.List<A>": "IList<E2>",
                        "java.util.Map<K, V>": "IDictionary<K, V>",
                        "java.util.Collection<E2>": "ICollection<E2>",
                        "java.util.Collection<B>": "ICollection<E2>",
                        "java.util.Map<K, java.lang.Long>": "IDictionary<K, 
long>",
-                       "TraversalMetrics": "E2"]
+                       "TraversalMetrics": "E2",
+                       "Traversal": "ITraversal",
+                       "Traversal[]": "ITraversal[]",
+                       "Predicate": "TraversalPredicate",
+                       "P": "TraversalPredicate",
+                       "TraversalStrategy": "ITraversalStrategy",
+                       "TraversalStrategy[]": "ITraversalStrategy[]",
+                       "Function": "object",
+                       "BiFunction": "object",
+                       "UnaryOperator": "object",
+                       "BinaryOperator": "object",
+                       "Consumer": "object",
+                       "Supplier": "object",
+                       "Comparator": "object",
+                       "VertexProgram": "object"]
 
 def useE2 = ["E2", "E2"];
 def methodsWithSpecificTypes = ["constant": useE2,
@@ -101,11 +121,9 @@ def methodsWithSpecificTypes = ["constant": useE2,
                                 "mean": useE2,
                                 "optional": useE2,
                                 "range": useE2,
-                                "select": ["IDictionary<string, E2>", "E2"],
                                 "sum": useE2,
                                 "tail": useE2,
-                                "tree": ["object"],
-                                "unfold": useE2]
+                                "unfold": useE2]                               
                         
 
 def getCSharpGenericTypeParam = { typeName ->
     def tParam = ""
@@ -131,7 +149,7 @@ def toCSharpType = { name ->
 
 def toCSharpMethodName = { symbol -> (String) 
Character.toUpperCase(symbol.charAt(0)) + symbol.substring(1) }
 
-def getJavaParameterTypeNames = { method ->
+def getJavaGenericTypeParameterTypeNames = { method ->
     def typeArguments = method.genericReturnType.actualTypeArguments;
     return typeArguments.
             collect { (it instanceof Class) ? ((Class)it).simpleName : 
it.typeName }.
@@ -146,6 +164,89 @@ def getJavaParameterTypeNames = { method ->
             }
 }
 
+def getJavaParameterTypeNames = { method ->
+    return method.parameters.
+            collect { param ->
+                param.type.simpleName
+            } 
+}
+
+def toCSharpParamString = { param ->
+    def csharpParamTypeName = toCSharpType(param.type.simpleName) // local: don't leak into the script Binding
+    "${csharpParamTypeName} ${param.name}"
+}
+
+def getJavaParamTypeString = { method ->
+    getJavaParameterTypeNames(method).join(",")
+}
+
+def getCSharpParamTypeString = { method ->
+    return method.parameters.
+            collect { param ->
+                toCSharpType(param.type.simpleName)
+            }.join(",")
+}
+
+def getCSharpParamString = { method ->
+    def parameters = method.parameters;
+    if (parameters.length == 0)
+        return ""
+    def csharpParameters = parameters.
+            init().
+            collect { param ->
+                toCSharpParamString(param)
+            };
+    def lastCSharpParam = "";
+    if (method.isVarArgs())
+        lastCSharpParam += "params "; // Java varargs -> C# params array on the last parameter
+    lastCSharpParam += toCSharpParamString(parameters.last())
+    csharpParameters += lastCSharpParam
+    def csharpParamString = csharpParameters.join(", ") // local: don't leak into the script Binding
+    csharpParamString
+}
+
+def getParamNames = { parameters ->
+    return parameters.
+        collect { param ->
+            param.name
+        }
+}
+
+def hasMethodNoGenericCounterPartInGraphTraversal = { method ->
+    def parameterTypeNames = getJavaParameterTypeNames(method)
+    if (method.name.equals("fold")) {
+        return parameterTypeNames.size() == 0
+    }
+    if (method.name.equals("limit")) {
+        if (parameterTypeNames.size() == 1) {
+            return parameterTypeNames[0].equals("long")
+        }
+    }
+    if (method.name.equals("range")) {
+        if (parameterTypeNames.size() == 2) {
+            return parameterTypeNames[0].equals("long") && 
parameterTypeNames[1].equals("long")
+        }
+    }
+    if (method.name.equals("tail")) {
+        if (parameterTypeNames.size() == 0) {
+            return true
+        }
+        if (parameterTypeNames.size() == 1) {
+            return parameterTypeNames[0].equals("long")
+        }
+    }
+    return false
+}
+
+def t2withSpecialGraphTraversalt2 = ["IList<E2>": "E2"]
+
+def getGraphTraversalT2ForT2 = { t2 ->
+    if (t2withSpecialGraphTraversalt2.containsKey(t2)) {
+        return t2withSpecialGraphTraversalt2.get(t2)
+    }
+    return t2
+}
+
 def binding = ["pmethods": P.class.getMethods().
                                  findAll { 
Modifier.isStatic(it.getModifiers()) }.
                                  findAll { 
P.class.isAssignableFrom(it.returnType) }.
@@ -160,38 +261,45 @@ def binding = ["pmethods": P.class.getMethods().
                                                     
!it.name.equals(TraversalSource.Symbols.withRemote) &&
                                                     
!it.name.equals(TraversalSource.Symbols.withComputer)
                                         }.
-                                        collect { it.name }.
-                                        unique().
-                                        sort { a, b -> a <=> b },
+                                        sort { a, b -> a.name <=> b.name ?: 
getJavaParamTypeString(a) <=> getJavaParamTypeString(b) }.
+                                        unique { a,b -> a.name <=> b.name ?: 
getCSharpParamTypeString(a) <=> getCSharpParamTypeString(b) }.
+                                        collect { javaMethod ->
+                                            def parameters = 
getCSharpParamString(javaMethod)
+                                            def paramNames = 
getParamNames(javaMethod.parameters)
+                                            return ["methodName": 
javaMethod.name, "parameters":parameters, "paramNames":paramNames]
+                                        },
                "sourceSpawnMethods": GraphTraversalSource.getMethods(). // 
SPAWN STEPS
-                                        findAll { 
GraphTraversal.class.equals(it.returnType) && !it.name.equals('inject')}.
-                                        collect { [methodName: it.name, 
typeArguments: it.genericReturnType.actualTypeArguments.collect{t -> 
((java.lang.Class)t).simpleName}] }.
-                                        unique().
-                                        sort { a, b -> a.methodName <=> 
b.methodName },
+                                        findAll { 
GraphTraversal.class.equals(it.returnType) && !it.name.equals('inject')}.       
                                                                       
+                                        sort { a, b -> a.name <=> b.name ?: 
getJavaParamTypeString(a) <=> getJavaParamTypeString(b) }.
+                                        unique { a,b -> a.name <=> b.name ?: 
getCSharpParamTypeString(a) <=> getCSharpParamTypeString(b) }.
+                                        collect { javaMethod ->
+                                            def typeArguments = 
javaMethod.genericReturnType.actualTypeArguments.collect{t -> 
((java.lang.Class)t).simpleName}
+                                            def parameters = 
getCSharpParamString(javaMethod)
+                                            def paramNames = 
getParamNames(javaMethod.parameters)
+                                            return ["methodName": 
javaMethod.name, "typeArguments": typeArguments, "parameters":parameters, 
"paramNames":paramNames]
+                                        },
                "graphStepMethods": GraphTraversal.getMethods().
                                         findAll { 
GraphTraversal.class.equals(it.returnType) }.
                                         findAll { !it.name.equals("clone") && 
!it.name.equals("iterate") }.
-                                        groupBy { it.name }.
-                                        // Select unique by name, with the 
most amount of parameters
-                                        collect { it.value.sort { a, b -> 
b.parameterCount <=> a.parameterCount }.first() }.
-                                        sort { a, b -> a.name <=> b.name }.
+                                        sort { a, b -> a.name <=> b.name ?: 
getJavaParamTypeString(a) <=> getJavaParamTypeString(b) }.
+                                        unique { a,b -> a.name <=> b.name ?: 
getCSharpParamTypeString(a) <=> getCSharpParamTypeString(b) }.
                                         collect { javaMethod ->
-                                            def typeNames = 
getJavaParameterTypeNames(javaMethod)
+                                            def typeNames = 
getJavaGenericTypeParameterTypeNames(javaMethod)
                                             def t1 = toCSharpType(typeNames[0])
                                             def t2 = toCSharpType(typeNames[1])
                                             def tParam = 
getCSharpGenericTypeParam(t2)
-                                            return ["methodName": 
javaMethod.name, "t1":t1, "t2":t2, "tParam":tParam]
+                                            def parameters = 
getCSharpParamString(javaMethod)
+                                            def paramNames = 
getParamNames(javaMethod.parameters)
+                                            return ["methodName": 
javaMethod.name, "t1":t1, "t2":t2, "tParam":tParam, "parameters":parameters, 
"paramNames":paramNames]
                                         },
                "anonStepMethods": __.class.getMethods().
                                         findAll { 
GraphTraversal.class.equals(it.returnType) }.
                                         findAll { 
Modifier.isStatic(it.getModifiers()) }.
                                         findAll { !it.name.equals("__") && 
!it.name.equals("start") }.
-                                        groupBy { it.name }.
-                                        // Select unique by name, with the 
most amount of parameters
-                                        collect { it.value.sort { a, b -> 
b.parameterCount <=> a.parameterCount }.first() }.
-                                        sort { it.name }.
+                                        sort { a, b -> a.name <=> b.name ?: 
getJavaParamTypeString(a) <=> getJavaParamTypeString(b) }.
+                                        unique { a,b -> a.name <=> b.name ?: 
getCSharpParamTypeString(a) <=> getCSharpParamTypeString(b) }.
                                         collect { javaMethod ->
-                                            def typeNames = 
getJavaParameterTypeNames(javaMethod)
+                                            def typeNames = 
getJavaGenericTypeParameterTypeNames(javaMethod)
                                             def t2 = toCSharpType(typeNames[1])
                                             def tParam = 
getCSharpGenericTypeParam(t2)
                                             def specificTypes = 
methodsWithSpecificTypes.get(javaMethod.name)
@@ -199,7 +307,11 @@ def binding = ["pmethods": P.class.getMethods().
                                                 t2 = specificTypes[0]
                                                 tParam = specificTypes.size() 
> 1 ? "<" + specificTypes[1] + ">" : ""
                                             }
-                                            return ["methodName": 
javaMethod.name, "t2":t2, "tParam":tParam]
+                                            def parameters = 
getCSharpParamString(javaMethod)
+                                            def paramNames = 
getParamNames(javaMethod.parameters)
+                                            def callGenericTypeArg = 
hasMethodNoGenericCounterPartInGraphTraversal(javaMethod) ? "" : tParam
+                                            def graphTraversalT2 = 
getGraphTraversalT2ForT2(t2)
+                                            return ["methodName": 
javaMethod.name, "t2":t2, "tParam":tParam, "parameters":parameters, 
"paramNames":paramNames, "callGenericTypeArg":callGenericTypeArg, 
"graphTraversalT2":graphTraversalT2]
                                         },
                "toCSharpMethodName": toCSharpMethodName]
 

http://git-wip-us.apache.org/repos/asf/tinkerpop/blob/05851764/gremlin-dotnet/src/Gremlin.Net/Process/Traversal/Bindings.cs
----------------------------------------------------------------------
diff --git a/gremlin-dotnet/src/Gremlin.Net/Process/Traversal/Bindings.cs 
b/gremlin-dotnet/src/Gremlin.Net/Process/Traversal/Bindings.cs
index 985369e..2aa532b 100644
--- a/gremlin-dotnet/src/Gremlin.Net/Process/Traversal/Bindings.cs
+++ b/gremlin-dotnet/src/Gremlin.Net/Process/Traversal/Bindings.cs
@@ -21,6 +21,9 @@
 
 #endregion
 
+using System.Collections.Generic;
+using System.Threading;
+
 namespace Gremlin.Net.Process.Traversal
 {
     /// <summary>
@@ -29,14 +32,42 @@ namespace Gremlin.Net.Process.Traversal
     public class Bindings
     {
         /// <summary>
+        ///     Gets an instance of the <see cref="Bindings" /> class.
+        /// </summary>
+        public static Bindings Instance { get; } = new Bindings();
+
+        private static readonly ThreadLocal<Dictionary<object, string>> 
BoundVariableByValue =
+            new ThreadLocal<Dictionary<object, string>>();
+
+        /// <summary>
         ///     Binds the variable to the specified value.
         /// </summary>
         /// <param name="variable">The variable to bind.</param>
         /// <param name="value">The value to which the variable should be 
bound.</param>
         /// <returns>The bound value.</returns>
-        public Binding Of(string variable, object value)
+        public TV Of<TV>(string variable, TV value)
+        {
+            var dict = BoundVariableByValue.Value;
+            if (dict == null)
+            {
+                dict = new Dictionary<object, string>();
+                BoundVariableByValue.Value = dict;
+            }
+            dict[value] = variable;
+            return value;
+        }
+
+        internal static string GetBoundVariable<TV>(TV value)
+        {
+            var dict = BoundVariableByValue.Value;
+            if (dict == null || value == null) // no bindings on this thread, or null (never a valid Dictionary key)
+                return null;
+            return dict.TryGetValue(value, out var variable) ? variable : null; // single lookup
+        }
+
+        internal static void Clear()
         {
-            return new Binding(variable, value);
+            BoundVariableByValue.Value?.Clear();
         }
     }
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/tinkerpop/blob/05851764/gremlin-dotnet/src/Gremlin.Net/Process/Traversal/Bytecode.cs
----------------------------------------------------------------------
diff --git a/gremlin-dotnet/src/Gremlin.Net/Process/Traversal/Bytecode.cs 
b/gremlin-dotnet/src/Gremlin.Net/Process/Traversal/Bytecode.cs
index b76f395..e09c533 100644
--- a/gremlin-dotnet/src/Gremlin.Net/Process/Traversal/Bytecode.cs
+++ b/gremlin-dotnet/src/Gremlin.Net/Process/Traversal/Bytecode.cs
@@ -21,7 +21,10 @@
 
 #endregion
 
+using System;
+using System.Collections;
 using System.Collections.Generic;
+using System.Linq;
 
 namespace Gremlin.Net.Process.Traversal
 {
@@ -35,6 +38,8 @@ namespace Gremlin.Net.Process.Traversal
     /// </remarks>
     public class Bytecode
     {
+        private static readonly object[] EmptyArray = new object[0];
+
         /// <summary>
         ///     Initializes a new instance of the <see cref="Bytecode" /> 
class.
         /// </summary>
@@ -69,7 +74,8 @@ namespace Gremlin.Net.Process.Traversal
         /// <param name="args">The traversal source method arguments.</param>
         public void AddSource(string sourceName, params object[] args)
         {
-            SourceInstructions.Add(new Instruction(sourceName, args));
+            SourceInstructions.Add(new Instruction(sourceName, 
FlattenArguments(args)));
+            Bindings.Clear();
         }
 
         /// <summary>
@@ -79,7 +85,77 @@ namespace Gremlin.Net.Process.Traversal
         /// <param name="args">The traversal method arguments.</param>
         public void AddStep(string stepName, params object[] args)
         {
-            StepInstructions.Add(new Instruction(stepName, args));
+            StepInstructions.Add(new Instruction(stepName, 
FlattenArguments(args)));
+            Bindings.Clear();
+        }
+
+        private object[] FlattenArguments(object[] arguments)
+        {
+            if (arguments.Length == 0)
+                return EmptyArray;
+            var flatArguments = new List<object>();
+            foreach (var arg in arguments)
+            {
+                if (arg is object[] objects)
+                {
+                    flatArguments.AddRange(objects.Select(nestObject => 
ConvertArgument(nestObject, true)));
+                }
+                else
+                {
+                    flatArguments.Add(ConvertArgument(arg, true));
+                }
+            }
+            return flatArguments.ToArray();
+        }
+
+        private object ConvertArgument(object argument, bool searchBindings)
+        {
+            if (argument == null) // null has no runtime type to inspect and cannot be a bound value
+                return null;
+            var variable = searchBindings ? Bindings.GetBoundVariable(argument) : null;
+            if (variable != null)
+                return new Binding(variable, ConvertArgument(argument, false));
+            var type = argument.GetType();
+            if (IsDictionaryType(type))
+            {
+                var dict = new Dictionary<object, object>();
+                foreach (DictionaryEntry item in (IDictionary)argument)
+                {
+                    dict[ConvertArgument(item.Key, true)] = ConvertArgument(item.Value, true);
+                }
+                return dict;
+            }
+            if (IsListType(type))
+            {
+                var list = new List<object>(((IList) argument).Count);
+                list.AddRange(from object item in (IList) argument select ConvertArgument(item, true));
+                return list;
+            }
+            if (IsHashSetType(type))
+            {
+                var set = new HashSet<object>();
+                foreach (var item in (IEnumerable)argument)
+                {
+                    set.Add(ConvertArgument(item, true));
+                }
+                return set;
+            }
+            return argument;
+        }
+
+        private bool IsDictionaryType(Type type)
+        {
+            return type.IsConstructedGenericType && 
type.GetGenericTypeDefinition() == typeof(Dictionary<,>);
+        }
+
+        private bool IsListType(Type type)
+        {
+            return type.IsConstructedGenericType && 
type.GetGenericTypeDefinition() == typeof(List<>);
+        }
+
+        private bool IsHashSetType(Type type)
+        {
+            return type.IsConstructedGenericType && 
type.GetGenericTypeDefinition() == typeof(HashSet<>);
         }
     }
 }
\ No newline at end of file

Reply via email to