This is an automated email from the ASF dual-hosted git repository. Cole-Greer pushed a commit to branch simplePDT in repository https://gitbox.apache.org/repos/asf/tinkerpop.git
commit 002865a63cefa5d632438766b11c180eb3a8d2db Author: Cole Greer <[email protected]> AuthorDate: Wed Jun 24 20:38:54 2026 -0700 Add PrimitivePDT support to gremlin-go GLV. Implements PrimitivePDT in the Go GLV, mirroring the existing composite PDT support and applying the lessons learned in the Python GLV review. - PrimitiveProviderDefinedType struct {Name, Value} in providerDefinedType.go. - primitivePDTType (0xf1): type resolution for *PrimitiveProviderDefinedType, primitivePdtWriter (two fully-qualified Strings), a writer-map entry, and readPrimitivePDT plus its deserializer switch case. - PDTRegistry gains an explicit primitive adapter path (RegisterPrimitiveFuncs[WithType], HydratePrimitive) alongside the composite path. - gremlin-lang text translation (gremlinlang.go) emits PDT("name","value") for *PrimitiveProviderDefinedType and auto-dehydrates values whose type has a registered primitive adapter (the primitive adapter is checked before the composite one). This client-side text path was the gap found in the Python GLV, and it is unit-tested here. - Client wiring reuses the existing PDTRegistry path. No GraphSON g:PrimitivePdt read path is added (consistent with the Go driver's GraphBinary-based V4 response handling). Tests: unit tests for serializer/deserializer round-trip (including opaque-value fidelity: leading zeros, large, non-numeric, and empty values), gremlin-lang text emission, adapter dehydration with primitive-over-composite precedence, and registry hydration (no adapter -> raw PDT, adapter error -> raw PDT, nil input, primitive nested in composite) — all passing. Integration tests (unregistered, registered de/hydration, nested, in-collection) pass against the test server. 
tinkerpop-2gy.10 Assisted-by: Kiro:claude-opus-4.8 --- gremlin-go/driver/graphBinaryDeserializer.go | 29 +++++ gremlin-go/driver/graphBinarySerializer.go | 3 + gremlin-go/driver/graphBinarySerializer_test.go | 138 +++++++++++++++++++++-- gremlin-go/driver/gremlinlang.go | 12 +- gremlin-go/driver/gremlinlang_test.go | 86 ++++++++++++++ gremlin-go/driver/pdtRegistry.go | 60 +++++++++- gremlin-go/driver/pdtRegistry_test.go | 65 +++++++++++ gremlin-go/driver/providerDefinedType.go | 22 +++- gremlin-go/driver/providerDefinedType_test.go | 10 ++ gremlin-go/driver/serializer.go | 3 +- gremlin-go/driver/traversal_test.go | 142 +++++++++++++++++++++++- 11 files changed, 552 insertions(+), 18 deletions(-) diff --git a/gremlin-go/driver/graphBinaryDeserializer.go b/gremlin-go/driver/graphBinaryDeserializer.go index 6b06bc6e33..865e25cd53 100644 --- a/gremlin-go/driver/graphBinaryDeserializer.go +++ b/gremlin-go/driver/graphBinaryDeserializer.go @@ -278,6 +278,8 @@ func (d *GraphBinaryDeserializer) readValue(dt dataType, flag byte) (interface{} return d.readEnum(dt) case compositePDTType: return d.readCompositePDT() + case primitivePDTType: + return d.readPrimitivePDT() default: return nil, newError(err0408GetSerializerToReadUnknownTypeError, dt) } @@ -843,6 +845,33 @@ func (d *GraphBinaryDeserializer) readCompositePDT() (interface{}, error) { return pdt, nil } +func (d *GraphBinaryDeserializer) readPrimitivePDT() (interface{}, error) { + nameObj, err := d.ReadFullyQualified() + if err != nil { + return nil, err + } + name, ok := nameObj.(string) + if !ok || name == "" { + return nil, fmt.Errorf("PrimitiveProviderDefinedType name must be a non-empty string") + } + valueObj, err := d.ReadFullyQualified() + if err != nil { + return nil, err + } + value, ok := valueObj.(string) + if !ok { + return nil, fmt.Errorf("PrimitiveProviderDefinedType value must be a string") + } + pdt := &PrimitiveProviderDefinedType{Name: name, Value: value} + if d.pdtRegistry != nil { + hydrated := 
d.pdtRegistry.HydratePrimitive(pdt) + if hydrated != pdt { + return hydrated, nil + } + } + return pdt, nil +} + // ReadStatus reads the response status after the EndOfStream marker. // Returns the status code, message, exception string, and any error encountered. // This should be called after ReadFullyQualified() returns an EndOfStream marker. diff --git a/gremlin-go/driver/graphBinarySerializer.go b/gremlin-go/driver/graphBinarySerializer.go index c1ec27b674..30eaa2cdde 100644 --- a/gremlin-go/driver/graphBinarySerializer.go +++ b/gremlin-go/driver/graphBinarySerializer.go @@ -65,6 +65,7 @@ const ( gTypeType dataType = 0x30 durationType dataType = 0x81 compositePDTType dataType = 0xf0 + primitivePDTType dataType = 0xf1 markerType dataType = 0xfd nullType dataType = 0xFE ) @@ -623,6 +624,8 @@ func (serializer *graphBinaryTypeSerializer) getType(val interface{}) (dataType, return byteBuffer, nil case *ProviderDefinedType: return compositePDTType, nil + case *PrimitiveProviderDefinedType: + return primitivePDTType, nil default: switch reflect.TypeOf(val).Kind() { case reflect.Map: diff --git a/gremlin-go/driver/graphBinarySerializer_test.go b/gremlin-go/driver/graphBinarySerializer_test.go index 73b059d5c6..9fc634bf22 100644 --- a/gremlin-go/driver/graphBinarySerializer_test.go +++ b/gremlin-go/driver/graphBinarySerializer_test.go @@ -96,16 +96,16 @@ func TestGraphBinaryV4(t *testing.T) { assert.NotNil(t, err) }) - t.Run("getType returns graphType for *Graph", func(t *testing.T) { - res, err := serializer.getType(NewGraph()) - assert.Nil(t, err) - assert.Equal(t, graphType, res) - }) + t.Run("getType returns graphType for *Graph", func(t *testing.T) { + res, err := serializer.getType(NewGraph()) + assert.Nil(t, err) + assert.Equal(t, graphType, res) + }) - t.Run("getWriter returns graphWriter for graphType", func(t *testing.T) { - _, err := serializer.getWriter(graphType) - assert.Nil(t, err) - }) + t.Run("getWriter returns graphWriter for graphType", func(t 
*testing.T) { + _, err := serializer.getWriter(graphType) + assert.Nil(t, err) + }) }) t.Run("read-write tests", func(t *testing.T) { @@ -820,3 +820,123 @@ func TestProviderDefinedTypeSerialization(t *testing.T) { assert.Equal(t, "com.example.MyType", pdt.Name) }) } + +func TestPrimitiveProviderDefinedTypeSerialization(t *testing.T) { + serializer := graphBinaryTypeSerializer{newLogHandler(&defaultLogger{}, Error, language.English)} + + t.Run("round-trip simple primitive PDT", func(t *testing.T) { + source := &PrimitiveProviderDefinedType{Name: "x:Uint32", Value: "42"} + var buf bytes.Buffer + err := serializer.write(source, &buf) + assert.Nil(t, err) + + d := NewGraphBinaryDeserializer(bytes.NewReader(buf.Bytes())) + result, err := d.ReadFullyQualified() + assert.Nil(t, err) + pdt, ok := result.(*PrimitiveProviderDefinedType) + assert.True(t, ok) + assert.Equal(t, "x:Uint32", pdt.Name) + assert.Equal(t, "42", pdt.Value) + }) + + t.Run("round-trip leading zeros", func(t *testing.T) { + source := &PrimitiveProviderDefinedType{Name: "x:ZipCode", Value: "00123"} + var buf bytes.Buffer + err := serializer.write(source, &buf) + assert.Nil(t, err) + + d := NewGraphBinaryDeserializer(bytes.NewReader(buf.Bytes())) + result, err := d.ReadFullyQualified() + assert.Nil(t, err) + pdt, ok := result.(*PrimitiveProviderDefinedType) + assert.True(t, ok) + assert.Equal(t, "00123", pdt.Value) + }) + + t.Run("round-trip large number", func(t *testing.T) { + source := &PrimitiveProviderDefinedType{Name: "x:BigNum", Value: "99999999999999999999"} + var buf bytes.Buffer + err := serializer.write(source, &buf) + assert.Nil(t, err) + + d := NewGraphBinaryDeserializer(bytes.NewReader(buf.Bytes())) + result, err := d.ReadFullyQualified() + assert.Nil(t, err) + pdt, ok := result.(*PrimitiveProviderDefinedType) + assert.True(t, ok) + assert.Equal(t, "99999999999999999999", pdt.Value) + }) + + t.Run("round-trip non-numeric value", func(t *testing.T) { + source := 
&PrimitiveProviderDefinedType{Name: "x:Label", Value: "hello world!"} + var buf bytes.Buffer + err := serializer.write(source, &buf) + assert.Nil(t, err) + + d := NewGraphBinaryDeserializer(bytes.NewReader(buf.Bytes())) + result, err := d.ReadFullyQualified() + assert.Nil(t, err) + pdt, ok := result.(*PrimitiveProviderDefinedType) + assert.True(t, ok) + assert.Equal(t, "hello world!", pdt.Value) + }) + + t.Run("round-trip empty value", func(t *testing.T) { + source := &PrimitiveProviderDefinedType{Name: "x:Empty", Value: ""} + var buf bytes.Buffer + err := serializer.write(source, &buf) + assert.Nil(t, err) + + d := NewGraphBinaryDeserializer(bytes.NewReader(buf.Bytes())) + result, err := d.ReadFullyQualified() + assert.Nil(t, err) + pdt, ok := result.(*PrimitiveProviderDefinedType) + assert.True(t, ok) + assert.Equal(t, "", pdt.Value) + }) + + t.Run("auto-hydrate with registry", func(t *testing.T) { + registry := NewPDTRegistry() + registry.RegisterPrimitiveFuncs("x:Uint32", + func(s string) (interface{}, error) { + return "hydrated:" + s, nil + }, nil) + + source := &PrimitiveProviderDefinedType{Name: "x:Uint32", Value: "42"} + var buf bytes.Buffer + err := serializer.write(source, &buf) + assert.Nil(t, err) + + d := NewGraphBinaryDeserializerWithRegistry(bytes.NewReader(buf.Bytes()), registry) + result, err := d.ReadFullyQualified() + assert.Nil(t, err) + assert.Equal(t, "hydrated:42", result) + }) + + t.Run("no hydration without registry", func(t *testing.T) { + source := &PrimitiveProviderDefinedType{Name: "x:Uint32", Value: "42"} + var buf bytes.Buffer + err := serializer.write(source, &buf) + assert.Nil(t, err) + + d := NewGraphBinaryDeserializer(bytes.NewReader(buf.Bytes())) + result, err := d.ReadFullyQualified() + assert.Nil(t, err) + pdt, ok := result.(*PrimitiveProviderDefinedType) + assert.True(t, ok) + assert.Equal(t, "x:Uint32", pdt.Name) + assert.Equal(t, "42", pdt.Value) + }) + + t.Run("empty name produces error", func(t *testing.T) { + data := 
[]byte{ + 0xf1, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, // fq string, length 0 + 0x03, 0x00, 0x00, 0x00, 0x00, 0x02, 0x34, 0x32, // fq string "42" + } + d := NewGraphBinaryDeserializer(bytes.NewReader(data)) + _, err := d.ReadFullyQualified() + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "non-empty") + }) +} diff --git a/gremlin-go/driver/gremlinlang.go b/gremlin-go/driver/gremlinlang.go index a08c6a7d31..792350fb16 100644 --- a/gremlin-go/driver/gremlinlang.go +++ b/gremlin-go/driver/gremlinlang.go @@ -136,7 +136,6 @@ func escapeString(s string) string { return sb.String() } - func (gl *GremlinLang) argAsString(arg interface{}) (string, error) { if arg == nil { return "null", nil @@ -215,6 +214,8 @@ func (gl *GremlinLang) argAsString(arg interface{}) (string, error) { return "", err } return fmt.Sprintf("PDT(\"%s\",%s)", escapeString(v.Name), mapStr), nil + case *PrimitiveProviderDefinedType: + return fmt.Sprintf("PDT(\"%s\",\"%s\")", escapeString(v.Name), escapeString(v.Value)), nil case *Vertex: return gl.argAsString(v.Id) case textP: @@ -303,6 +304,15 @@ func (gl *GremlinLang) argAsString(arg interface{}) (string, error) { // over any reflection/struct-based fallback, allowing explicit adapters to override // default behavior for a given Go type. if gl.pdtRegistry != nil { + // Check primitive adapter before composite (mandatory per Python review lesson). 
+ primitiveAdapter := gl.pdtRegistry.GetPrimitiveAdapterByType(reflect.TypeOf(arg)) + if primitiveAdapter != nil && primitiveAdapter.ToString != nil { + s, err := primitiveAdapter.ToString(arg) + if err == nil { + pdt := &PrimitiveProviderDefinedType{Name: primitiveAdapter.TypeName, Value: s} + return gl.argAsString(pdt) + } + } adapter := gl.pdtRegistry.GetAdapterByType(reflect.TypeOf(arg)) if adapter != nil && adapter.ToFields != nil { fields, err := adapter.ToFields(arg) diff --git a/gremlin-go/driver/gremlinlang_test.go b/gremlin-go/driver/gremlinlang_test.go index c025cc6d63..760dd2919f 100644 --- a/gremlin-go/driver/gremlinlang_test.go +++ b/gremlin-go/driver/gremlinlang_test.go @@ -975,3 +975,89 @@ func TestPDT_GremlinLang_NestedRegisteredInUnregisteredOuter(t *testing.T) { t.Errorf("nested dehydration: got %v, expected %v", gremlin, expected) } } + +func Test_PrimitivePDT_GremlinLang(t *testing.T) { + t.Run("basic primitive PDT", func(t *testing.T) { + g := NewGraphTraversalSource(nil, nil) + pdt := &PrimitiveProviderDefinedType{Name: "x:Uint32", Value: "42"} + gremlin := g.Inject(pdt).GremlinLang.GetGremlin() + expected := `g.inject(PDT("x:Uint32","42"))` + if gremlin != expected { + t.Errorf("got %v, expected %v", gremlin, expected) + } + }) + + t.Run("primitive PDT with special chars in value", func(t *testing.T) { + g := NewGraphTraversalSource(nil, nil) + pdt := &PrimitiveProviderDefinedType{Name: "x:Label", Value: `say"hello"`} + gremlin := g.Inject(pdt).GremlinLang.GetGremlin() + expected := `g.inject(PDT("x:Label","say\"hello\""))` + if gremlin != expected { + t.Errorf("got %v, expected %v", gremlin, expected) + } + }) + + t.Run("primitive PDT leading zeros preserved", func(t *testing.T) { + g := NewGraphTraversalSource(nil, nil) + pdt := &PrimitiveProviderDefinedType{Name: "x:ZipCode", Value: "00123"} + gremlin := g.Inject(pdt).GremlinLang.GetGremlin() + expected := `g.inject(PDT("x:ZipCode","00123"))` + if gremlin != expected { + t.Errorf("got %v, 
expected %v", gremlin, expected) + } + }) + + t.Run("primitive PDT empty value", func(t *testing.T) { + g := NewGraphTraversalSource(nil, nil) + pdt := &PrimitiveProviderDefinedType{Name: "x:Empty", Value: ""} + gremlin := g.Inject(pdt).GremlinLang.GetGremlin() + expected := `g.inject(PDT("x:Empty",""))` + if gremlin != expected { + t.Errorf("got %v, expected %v", gremlin, expected) + } + }) +} + +// primitiveAdapterUint32 is a test type for primitive PDT dehydration. +type primitiveAdapterUint32 uint32 + +func Test_PrimitivePDT_AdapterDehydration(t *testing.T) { + registry := NewPDTRegistry() + registry.RegisterPrimitiveFuncsWithType("x:Uint32", reflect.TypeOf(primitiveAdapterUint32(0)), + func(s string) (interface{}, error) { + return primitiveAdapterUint32(42), nil + }, + func(obj interface{}) (string, error) { + return fmt.Sprintf("%d", obj.(primitiveAdapterUint32)), nil + }) + + g := NewGraphTraversalSource(nil, nil) + g.GetGremlinLang().pdtRegistry = registry + + gremlin := g.Inject(primitiveAdapterUint32(99)).GremlinLang.GetGremlin() + expected := `g.inject(PDT("x:Uint32","99"))` + if gremlin != expected { + t.Errorf("primitive adapter dehydration: got %v, expected %v", gremlin, expected) + } +} + +func Test_PrimitivePDT_PrimitiveAdapterTakesPrecedenceOverComposite(t *testing.T) { + type dualType struct{ V int } + registry := NewPDTRegistry() + // Register both primitive and composite for the same Go type — primitive must win. 
+ registry.RegisterPrimitiveFuncsWithType("x:Dual", reflect.TypeOf(dualType{}), + func(s string) (interface{}, error) { return dualType{V: 1}, nil }, + func(obj interface{}) (string, error) { return "prim", nil }) + registry.RegisterFuncsWithType("x:Dual", reflect.TypeOf(dualType{}), + nil, + func(obj interface{}) (map[string]interface{}, error) { return map[string]interface{}{"v": 1}, nil }) + + g := NewGraphTraversalSource(nil, nil) + g.GetGremlinLang().pdtRegistry = registry + + gremlin := g.Inject(dualType{V: 1}).GremlinLang.GetGremlin() + expected := `g.inject(PDT("x:Dual","prim"))` + if gremlin != expected { + t.Errorf("primitive adapter should take precedence over composite: got %v, expected %v", gremlin, expected) + } +} diff --git a/gremlin-go/driver/pdtRegistry.go b/gremlin-go/driver/pdtRegistry.go index 7b78f4f2f0..6af7bfcf6b 100644 --- a/gremlin-go/driver/pdtRegistry.go +++ b/gremlin-go/driver/pdtRegistry.go @@ -23,20 +23,34 @@ import "reflect" // PDTAdapter defines how to hydrate/dehydrate a provider-defined type. type PDTAdapter struct { - TypeName string + TypeName string FromFields func(map[string]interface{}) (interface{}, error) ToFields func(interface{}) (map[string]interface{}, error) } +// PrimitivePDTAdapter defines how to hydrate/dehydrate a primitive provider-defined type. +type PrimitivePDTAdapter struct { + TypeName string + FromString func(string) (interface{}, error) + ToString func(interface{}) (string, error) +} + // PDTRegistry maps type names to their hydration adapters. type PDTRegistry struct { - adaptersByName map[string]*PDTAdapter - adaptersByType map[reflect.Type]*PDTAdapter + adaptersByName map[string]*PDTAdapter + adaptersByType map[reflect.Type]*PDTAdapter + primitiveAdaptersByName map[string]*PrimitivePDTAdapter + primitiveAdaptersByType map[reflect.Type]*PrimitivePDTAdapter } // NewPDTRegistry creates an empty PDTRegistry. 
func NewPDTRegistry() *PDTRegistry { - return &PDTRegistry{adaptersByName: make(map[string]*PDTAdapter), adaptersByType: make(map[reflect.Type]*PDTAdapter)} + return &PDTRegistry{ + adaptersByName: make(map[string]*PDTAdapter), + adaptersByType: make(map[reflect.Type]*PDTAdapter), + primitiveAdaptersByName: make(map[string]*PrimitivePDTAdapter), + primitiveAdaptersByType: make(map[reflect.Type]*PrimitivePDTAdapter), + } } // RegisterFuncs registers hydration/dehydration functions for a type name. @@ -89,6 +103,8 @@ func (r *PDTRegistry) Hydrate(pdt *ProviderDefinedType) interface{} { for k, v := range pdt.Fields { if nested, ok := v.(*ProviderDefinedType); ok { hydratedFields[k] = r.Hydrate(nested) + } else if nested, ok := v.(*PrimitiveProviderDefinedType); ok { + hydratedFields[k] = r.HydratePrimitive(nested) } else { hydratedFields[k] = v } @@ -103,3 +119,39 @@ func (r *PDTRegistry) Hydrate(pdt *ProviderDefinedType) interface{} { } return result } + +// RegisterPrimitiveFuncs registers hydration/dehydration functions for a primitive type name. +func (r *PDTRegistry) RegisterPrimitiveFuncs(typeName string, fromString func(string) (interface{}, error), toString func(interface{}) (string, error)) { + adapter := &PrimitivePDTAdapter{TypeName: typeName, FromString: fromString, ToString: toString} + r.primitiveAdaptersByName[typeName] = adapter +} + +// RegisterPrimitiveFuncsWithType registers hydration/dehydration functions for a primitive type name +// and associates a Go type for dehydration lookup. 
+func (r *PDTRegistry) RegisterPrimitiveFuncsWithType(typeName string, targetType reflect.Type, fromString func(string) (interface{}, error), toString func(interface{}) (string, error)) { + adapter := &PrimitivePDTAdapter{TypeName: typeName, FromString: fromString, ToString: toString} + r.primitiveAdaptersByName[typeName] = adapter + r.primitiveAdaptersByType[targetType] = adapter +} + +// GetPrimitiveAdapterByType returns the primitive adapter registered for the given Go type, or nil. +func (r *PDTRegistry) GetPrimitiveAdapterByType(t reflect.Type) *PrimitivePDTAdapter { + return r.primitiveAdaptersByType[t] +} + +// HydratePrimitive converts a PrimitiveProviderDefinedType into a domain object using the registered primitive adapter. +// Returns the raw PDT if no adapter is found or if hydration fails. +func (r *PDTRegistry) HydratePrimitive(pdt *PrimitiveProviderDefinedType) interface{} { + if pdt == nil { + return nil + } + adapter, ok := r.primitiveAdaptersByName[pdt.Name] + if !ok { + return pdt + } + result, err := adapter.FromString(pdt.Value) + if err != nil { + return pdt + } + return result +} diff --git a/gremlin-go/driver/pdtRegistry_test.go b/gremlin-go/driver/pdtRegistry_test.go index ff8baa8cf8..3e57ba5e53 100644 --- a/gremlin-go/driver/pdtRegistry_test.go +++ b/gremlin-go/driver/pdtRegistry_test.go @@ -114,3 +114,68 @@ func TestPDTRegistryNestedHydration_UnregisteredOuter(t *testing.T) { // Non-PDT fields remain unchanged. 
assert.Equal(t, "test", pdt.Fields["label"]) } + +func TestPDTRegistryPrimitiveRegisterFuncsAndHydrate(t *testing.T) { + reg := NewPDTRegistry() + reg.RegisterPrimitiveFuncs("x:Uint32", func(s string) (interface{}, error) { + return "uint32:" + s, nil + }, nil) + + pdt := &PrimitiveProviderDefinedType{Name: "x:Uint32", Value: "42"} + result := reg.HydratePrimitive(pdt) + assert.Equal(t, "uint32:42", result) +} + +func TestPDTRegistryPrimitiveNoAdapterReturnsRawPDT(t *testing.T) { + reg := NewPDTRegistry() + pdt := &PrimitiveProviderDefinedType{Name: "x:Unknown", Value: "val"} + result := reg.HydratePrimitive(pdt) + assert.Equal(t, pdt, result) +} + +func TestPDTRegistryPrimitiveAdapterErrorReturnsRawPDT(t *testing.T) { + reg := NewPDTRegistry() + reg.RegisterPrimitiveFuncs("x:Bad", func(s string) (interface{}, error) { + return nil, errors.New("fail") + }, nil) + + pdt := &PrimitiveProviderDefinedType{Name: "x:Bad", Value: "val"} + result := reg.HydratePrimitive(pdt) + assert.Equal(t, pdt, result) +} + +func TestPDTRegistryPrimitiveHydrateNil(t *testing.T) { + reg := NewPDTRegistry() + assert.Nil(t, reg.HydratePrimitive(nil)) +} + +func TestPDTRegistryPrimitiveInsideComposite(t *testing.T) { + reg := NewPDTRegistry() + reg.RegisterPrimitiveFuncs("x:Uint32", func(s string) (interface{}, error) { + return "uint32:" + s, nil + }, nil) + + inner := &PrimitiveProviderDefinedType{Name: "x:Uint32", Value: "7"} + outer := &ProviderDefinedType{Name: "x:Outer", Fields: map[string]interface{}{"id": inner}} + result := reg.Hydrate(outer) + + pdt, ok := result.(*ProviderDefinedType) + assert.True(t, ok) + assert.Equal(t, "uint32:7", pdt.Fields["id"]) +} + +func TestPDTRegistryPrimitiveWithType(t *testing.T) { + type myID string + reg := NewPDTRegistry() + reg.RegisterPrimitiveFuncsWithType("x:MyID", reflect.TypeOf(myID("")), + func(s string) (interface{}, error) { + return myID(s), nil + }, + func(obj interface{}) (string, error) { + return string(obj.(myID)), nil + }) + + 
adapter := reg.GetPrimitiveAdapterByType(reflect.TypeOf(myID(""))) + assert.NotNil(t, adapter) + assert.Equal(t, "x:MyID", adapter.TypeName) +} diff --git a/gremlin-go/driver/providerDefinedType.go b/gremlin-go/driver/providerDefinedType.go index 2a5a359892..ea58de83ef 100644 --- a/gremlin-go/driver/providerDefinedType.go +++ b/gremlin-go/driver/providerDefinedType.go @@ -48,4 +48,24 @@ func pdtWriter(value interface{}, w io.Writer, typeSerializer *graphBinaryTypeSe m[k] = v } return typeSerializer.write(m, w) -} \ No newline at end of file +} + +// PrimitiveProviderDefinedType represents a primitive provider-defined type (PDT) in GraphBinary serialization. +// Wire format 0xf1: two fully-qualified Strings {name}{value}. +type PrimitiveProviderDefinedType struct { + Name string + Value string +} + +func (p *PrimitiveProviderDefinedType) String() string { + return fmt.Sprintf("pdt[%s]%s", p.Name, p.Value) +} + +// primitivePdtWriter serializes a PrimitiveProviderDefinedType as two fully-qualified strings. 
+func primitivePdtWriter(value interface{}, w io.Writer, typeSerializer *graphBinaryTypeSerializer) error { + pdt := value.(*PrimitiveProviderDefinedType) + if err := typeSerializer.write(pdt.Name, w); err != nil { + return err + } + return typeSerializer.write(pdt.Value, w) +} diff --git a/gremlin-go/driver/providerDefinedType_test.go b/gremlin-go/driver/providerDefinedType_test.go index 8d0f394e24..79673ca263 100644 --- a/gremlin-go/driver/providerDefinedType_test.go +++ b/gremlin-go/driver/providerDefinedType_test.go @@ -34,3 +34,13 @@ func TestProviderDefinedType(t *testing.T) { assert.Contains(t, pdt.String(), "pdt[com.example.Test]") }) } + +func TestPrimitiveProviderDefinedType(t *testing.T) { + t.Run("String method", func(t *testing.T) { + pdt := &PrimitiveProviderDefinedType{ + Name: "x:Uint32", + Value: "42", + } + assert.Contains(t, pdt.String(), "pdt[x:Uint32]42") + }) +} diff --git a/gremlin-go/driver/serializer.go b/gremlin-go/driver/serializer.go index 38d618600c..859acc0c28 100644 --- a/gremlin-go/driver/serializer.go +++ b/gremlin-go/driver/serializer.go @@ -222,7 +222,6 @@ func initSerializers() { byteBuffer: byteBufferWriter, markerType: markerWriter, compositePDTType: pdtWriter, + primitivePDTType: primitivePdtWriter, } } - - diff --git a/gremlin-go/driver/traversal_test.go b/gremlin-go/driver/traversal_test.go index 40a20d2904..9e3a8e3da5 100644 --- a/gremlin-go/driver/traversal_test.go +++ b/gremlin-go/driver/traversal_test.go @@ -21,6 +21,7 @@ package gremlingo import ( "crypto/tls" + "fmt" "reflect" "strings" "testing" @@ -577,4 +578,143 @@ func TestProviderDefinedTypeTraversalAPIIntegration(t *testing.T) { type regPoint struct { X int32 Y int32 -} \ No newline at end of file +} + +func TestPrimitiveProviderDefinedTypeTraversalAPIIntegration(t *testing.T) { + testNoAuthUrl := getEnvOrDefaultString("GREMLIN_SERVER_URL", noAuthUrl) + testNoAuthEnable := getEnvOrDefaultBool("RUN_INTEGRATION_TESTS", true) + + t.Run("unregistered raw primitive 
PDT round-trip", func(t *testing.T) { + skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) + remote, err := NewDriverRemoteConnection(testNoAuthUrl, + func(settings *DriverRemoteConnectionSettings) { + settings.TlsConfig = &tls.Config{} + settings.TraversalSource = testServerModernGraphAlias + }) + require.NoError(t, err) + defer remote.Close() + + g := Traversal_().With(remote) + pdt := &PrimitiveProviderDefinedType{Name: "UnregisteredPrim", Value: "00123"} + + results, err := g.Inject(pdt).ToList() + require.NoError(t, err) + require.Len(t, results, 1) + + result, ok := results[0].GetInterface().(*PrimitiveProviderDefinedType) + require.True(t, ok, "expected *PrimitiveProviderDefinedType, got %T", results[0].GetInterface()) + assert.Equal(t, "UnregisteredPrim", result.Name) + assert.Equal(t, "00123", result.Value) + }) + + t.Run("registered primitive PDT auto-dehydrate and hydrate", func(t *testing.T) { + skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) + + type tinkerId string + registry := NewPDTRegistry() + registry.RegisterPrimitiveFuncsWithType("x:TinkerId", reflect.TypeOf(tinkerId("")), + func(s string) (interface{}, error) { return tinkerId(s), nil }, + func(obj interface{}) (string, error) { return string(obj.(tinkerId)), nil }) + + remote, err := NewDriverRemoteConnection(testNoAuthUrl, + func(settings *DriverRemoteConnectionSettings) { + settings.TlsConfig = &tls.Config{} + settings.TraversalSource = testServerModernGraphAlias + settings.PDTRegistry = registry + }) + require.NoError(t, err) + defer remote.Close() + + g := Traversal_().With(remote) + id := tinkerId("abc-123") + + results, err := g.Inject(id).ToList() + require.NoError(t, err) + require.Len(t, results, 1) + + result, ok := results[0].GetInterface().(tinkerId) + require.True(t, ok, "expected tinkerId, got %T", results[0].GetInterface()) + assert.Equal(t, tinkerId("abc-123"), result) + }) + + t.Run("registered primitive nested in composite", func(t 
*testing.T) { + skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) + + type myUint32 uint32 + registry := NewPDTRegistry() + registry.RegisterPrimitiveFuncsWithType("x:MyUint32", reflect.TypeOf(myUint32(0)), + func(s string) (interface{}, error) { + var v uint32 + fmt.Sscanf(s, "%d", &v) + return myUint32(v), nil + }, + func(obj interface{}) (string, error) { + return fmt.Sprintf("%d", obj.(myUint32)), nil + }) + + remote, err := NewDriverRemoteConnection(testNoAuthUrl, + func(settings *DriverRemoteConnectionSettings) { + settings.TlsConfig = &tls.Config{} + settings.TraversalSource = testServerModernGraphAlias + settings.PDTRegistry = registry + }) + require.NoError(t, err) + defer remote.Close() + + g := Traversal_().With(remote) + // Use a composite PDT with a primitive field + pdt := &ProviderDefinedType{ + Name: "Measurement", + Fields: map[string]interface{}{ + "unit": "kg", + "value": &PrimitiveProviderDefinedType{Name: "x:MyUint32", Value: "100"}, + }, + } + + results, err := g.Inject(pdt).ToList() + require.NoError(t, err) + require.Len(t, results, 1) + + result, ok := results[0].GetInterface().(*ProviderDefinedType) + require.True(t, ok, "expected *ProviderDefinedType, got %T", results[0].GetInterface()) + assert.Equal(t, "Measurement", result.Name) + assert.Equal(t, "kg", result.Fields["unit"]) + // The primitive value should be hydrated to myUint32 + hydrated, ok := result.Fields["value"].(myUint32) + require.True(t, ok, "expected myUint32, got %T", result.Fields["value"]) + assert.Equal(t, myUint32(100), hydrated) + }) + + t.Run("primitive PDT in collection", func(t *testing.T) { + skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) + remote, err := NewDriverRemoteConnection(testNoAuthUrl, + func(settings *DriverRemoteConnectionSettings) { + settings.TlsConfig = &tls.Config{} + settings.TraversalSource = testServerModernGraphAlias + }) + require.NoError(t, err) + defer remote.Close() + + g := Traversal_().With(remote) 
+ list := []interface{}{ + &PrimitiveProviderDefinedType{Name: "x:Val", Value: "one"}, + &PrimitiveProviderDefinedType{Name: "x:Val", Value: "two"}, + } + + results, err := g.Inject(list).ToList() + require.NoError(t, err) + require.Len(t, results, 1) + + resultList, ok := results[0].GetInterface().([]interface{}) + require.True(t, ok, "expected []interface{}, got %T", results[0].GetInterface()) + require.Len(t, resultList, 2) + + p1, ok := resultList[0].(*PrimitiveProviderDefinedType) + require.True(t, ok) + assert.Equal(t, "one", p1.Value) + + p2, ok := resultList[1].(*PrimitiveProviderDefinedType) + require.True(t, ok) + assert.Equal(t, "two", p2.Value) + }) +}
