This is an automated email from the ASF dual-hosted git repository.

chaokunyang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/fory.git


The following commit(s) were added to refs/heads/main by this push:
     new 1331381ae feat(go): add support for bfloat16 (#3310)
1331381ae is described below

commit 1331381ae968d792a8c7ab54736bb8c5a0084100
Author: Chang-Yen (Brian) Li <[email protected]>
AuthorDate: Wed Feb 11 15:39:28 2026 -0600

    feat(go): add support for bfloat16 (#3310)
    
    ## Why?
    
    We want to use `bfloat16` (BF16) in FDL to reduce payload size while
    keeping a wide exponent range (common in ML/AI workflows). Fory
    currently lacks a BF16 primitive and optimized arrays. This PR adds full
    support for `bfloat16` in the Fory Go runtime and codegen.
    
    ## What does this PR do?
    
    1.  **Compiler**:
        *   Adds `bfloat16` to the IR type system (`PrimitiveKind`).
        *   Updates the Go generator to map `bfloat16` to `bfloat16.BFloat16`
            and emit the matching import.
        *   Adds codegen tests for `bfloat16`.
    2.  **Go Runtime**:
        *   Adds a new package `go/fory/bfloat16` with a strong type
            `type BFloat16 uint16`.
        *   Implements conversions between `float32` and `bfloat16` using
            IEEE 754-style round-to-nearest, ties-to-even.
        *   Handles special values: NaN, Inf, ±0 correctly.
    3.  **Serialization**:
        *   Implements optimized serializers for `bfloat16`, `[]bfloat16`, and
            `[N]bfloat16`, mirroring the structure of the `float16` implementation.
    4.  **Type System**:
        *   Updates `TypeResolver` to register the `BFloat16`, `[]BFloat16`, and
            `[N]BFloat16` types.
        *   Adds type resolution logic that distinguishes `BFloat16` from
            `uint16` so the correct serializer is selected. `BFloat16` is a
            defined type (not an alias) whose underlying kind is `uint16`, so
            `reflect.Kind` alone cannot tell the two apart.
    5.  **Tests**: Adds unit tests for conversions, rounding logic,
        serialization, and codegen.
    
    ## Related issues
    
    Fixes #3284
    
    ## Does this PR introduce any user-facing change?
    
    
    
    - [x] Does this PR introduce any public API change?
    - [x] Does this PR introduce any binary protocol compatibility change?
    
    ## Benchmark
---
 compiler/fory_compiler/generators/go.py            |  3 +
 compiler/fory_compiler/ir/types.py                 |  2 +
 .../fory_compiler/tests/test_generated_code.py     | 30 +++++++
 go/fory/array_primitive.go                         | 75 +++++++++++++++++
 go/fory/array_primitive_test.go                    | 16 ++++
 go/fory/bfloat16/bfloat16.go                       | 72 ++++++++++++++++
 go/fory/bfloat16/bfloat16_test.go                  | 96 ++++++++++++++++++++++
 go/fory/primitive.go                               | 53 ++++++++++++
 go/fory/primitive_test.go                          | 34 ++++++++
 go/fory/skip.go                                    |  4 +-
 go/fory/slice_primitive.go                         | 79 ++++++++++++++++++
 go/fory/slice_primitive_test.go                    | 43 ++++++++++
 go/fory/type_resolver.go                           | 14 ++++
 go/fory/type_test.go                               |  6 ++
 14 files changed, 526 insertions(+), 1 deletion(-)

diff --git a/compiler/fory_compiler/generators/go.py 
b/compiler/fory_compiler/generators/go.py
index d7784a0df..0fd9d4b0c 100644
--- a/compiler/fory_compiler/generators/go.py
+++ b/compiler/fory_compiler/generators/go.py
@@ -190,6 +190,7 @@ class GoGenerator(BaseGenerator):
         PrimitiveKind.VAR_UINT64: "uint64",
         PrimitiveKind.TAGGED_UINT64: "uint64",
         PrimitiveKind.FLOAT16: "float16.Float16",
+        PrimitiveKind.BFLOAT16: "bfloat16.BFloat16",
         PrimitiveKind.FLOAT32: "float32",
         PrimitiveKind.FLOAT64: "float64",
         PrimitiveKind.STRING: "string",
@@ -1090,6 +1091,8 @@ class GoGenerator(BaseGenerator):
                 imports.add('"time"')
             elif field_type.kind == PrimitiveKind.FLOAT16:
                 imports.add('float16 "github.com/apache/fory/go/fory/float16"')
+            elif field_type.kind == PrimitiveKind.BFLOAT16:
+                imports.add('bfloat16 
"github.com/apache/fory/go/fory/bfloat16"')
 
         elif isinstance(field_type, ListType):
             self.collect_imports(field_type.element_type, imports)
diff --git a/compiler/fory_compiler/ir/types.py 
b/compiler/fory_compiler/ir/types.py
index 3dfc3d8ed..facc95ef6 100644
--- a/compiler/fory_compiler/ir/types.py
+++ b/compiler/fory_compiler/ir/types.py
@@ -39,6 +39,7 @@ class PrimitiveKind(PyEnum):
     VAR_UINT64 = "var_uint64"
     TAGGED_UINT64 = "tagged_uint64"
     FLOAT16 = "float16"
+    BFLOAT16 = "bfloat16"
     FLOAT32 = "float32"
     FLOAT64 = "float64"
     STRING = "string"
@@ -67,6 +68,7 @@ PRIMITIVE_TYPES = {
     "fixed_uint64": PrimitiveKind.UINT64,
     "tagged_uint64": PrimitiveKind.TAGGED_UINT64,
     "float16": PrimitiveKind.FLOAT16,
+    "bfloat16": PrimitiveKind.BFLOAT16,
     "float32": PrimitiveKind.FLOAT32,
     "float64": PrimitiveKind.FLOAT64,
     "string": PrimitiveKind.STRING,
diff --git a/compiler/fory_compiler/tests/test_generated_code.py 
b/compiler/fory_compiler/tests/test_generated_code.py
index 7fda28f48..13dc99eab 100644
--- a/compiler/fory_compiler/tests/test_generated_code.py
+++ b/compiler/fory_compiler/tests/test_generated_code.py
@@ -497,3 +497,33 @@ def test_generated_code_tree_ref_options_equivalent():
 
     cpp_output = render_files(generate_files(schemas["fdl"], CppGenerator))
     assert "SharedWeak<TreeNode>" in cpp_output
+
+
+def test_go_bfloat16_generation():
+    """Go codegen maps FDL ``bfloat16`` fields to ``bfloat16.BFloat16``.
+
+    Covers a plain field, an ``optional`` field (emitted as
+    ``optional.Optional[...]`` with a ``nullable`` tag) and a
+    ``list<bfloat16>`` field, and checks that the ``bfloat16`` package
+    import is emitted.
+    """
+    idl = dedent(
+        """
+        package bfloat16_test;
+
+        message BFloat16Message {
+            bfloat16 val = 1;
+            optional bfloat16 opt_val = 2;
+            list<bfloat16> list_val = 3;
+        }
+        """
+    )
+    schema = parse_fdl(idl)
+    files = generate_files(schema, GoGenerator)
+
+    # A single package produces a single Go file.
+    assert len(files) == 1
+    content = list(files.values())[0]
+
+    # Check imports: the generated file must import the bfloat16 package.
+    assert 'bfloat16 "github.com/apache/fory/go/fory/bfloat16"' in content
+
+    # Check fields (generated struct fields are tab-indented).
+    assert '\tVal bfloat16.BFloat16 `fory:"id=1"`' in content
+    assert (
+        '\tOptVal optional.Optional[bfloat16.BFloat16] `fory:"id=2,nullable"`'
+        in content
+    )
+    assert "\tListVal []bfloat16.BFloat16" in content
diff --git a/go/fory/array_primitive.go b/go/fory/array_primitive.go
index 977773359..27813060b 100644
--- a/go/fory/array_primitive.go
+++ b/go/fory/array_primitive.go
@@ -21,6 +21,7 @@ import (
        "reflect"
        "unsafe"
 
+       "github.com/apache/fory/go/fory/bfloat16"
        "github.com/apache/fory/go/fory/float16"
 )
 
@@ -871,3 +872,77 @@ func (s float16ArraySerializer) Read(ctx *ReadContext, 
refMode RefMode, readType
 func (s float16ArraySerializer) ReadWithTypeInfo(ctx *ReadContext, refMode 
RefMode, typeInfo *TypeInfo, value reflect.Value) {
        s.Read(ctx, refMode, false, false, value)
 }
+
+// ============================================================================
+// bfloat16ArraySerializer - optimized [N]bfloat16.BFloat16 serialization
+// ============================================================================
+
+type bfloat16ArraySerializer struct {
+       arrayType reflect.Type
+}
+
+// WriteData writes the array as a length-prefixed block of 2 bytes per
+// element. An addressable array on a little-endian host is written as its raw
+// bytes in one call; otherwise each element is written through WriteUint16.
+func (s bfloat16ArraySerializer) WriteData(ctx *WriteContext, value reflect.Value) {
+       buf := ctx.Buffer()
+       length := value.Len()
+       size := length * 2
+       buf.WriteLength(size)
+       if length == 0 {
+               return
+       }
+       if value.CanAddr() && isLittleEndian {
+               ptr := value.Addr().UnsafePointer()
+               buf.WriteBinary(unsafe.Slice((*byte)(ptr), size))
+               return
+       }
+       // A non-addressable array has no pointer to copy from (and big-endian
+       // hosts cannot use the raw bytes), so write element by element.
+       for i := 0; i < length; i++ {
+               val := value.Index(i).Interface().(bfloat16.BFloat16)
+               buf.WriteUint16(val.Bits())
+       }
+}
+
+func (s bfloat16ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, hasGenerics bool, value reflect.Value) {
+       writeArrayRefAndType(ctx, refMode, writeType, value, BFLOAT16_ARRAY)
+       if ctx.HasError() {
+               return
+       }
+       s.WriteData(ctx, value)
+}
+
+// ReadData reads a length-prefixed block of 2-byte elements into the
+// fixed-size array. A byte size that is odd, or that does not decode to the
+// array's length, is recorded on ctx as a deserialization error.
+func (s bfloat16ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) {
+       buf := ctx.Buffer()
+       ctxErr := ctx.Err()
+       size := buf.ReadLength(ctxErr)
+       if ctx.HasError() {
+               return
+       }
+       // Reject odd sizes up front: size/2 truncates, so an odd size could still
+       // match the array length while the unsafe copy below writes `size` bytes,
+       // one past the end of the array.
+       if size%2 != 0 {
+               ctx.SetError(DeserializationErrorf("bfloat16 array byte size %d is not a multiple of 2", size))
+               return
+       }
+       length := size / 2
+       if length != value.Type().Len() {
+               ctx.SetError(DeserializationErrorf("array length %d does not match type %v", length, value.Type()))
+               return
+       }
+       if length == 0 {
+               return
+       }
+       if isLittleEndian {
+               ptr := value.Addr().UnsafePointer()
+               raw := buf.ReadBinary(size, ctxErr)
+               copy(unsafe.Slice((*byte)(ptr), size), raw)
+               return
+       }
+       for i := 0; i < length; i++ {
+               value.Index(i).Set(reflect.ValueOf(bfloat16.BFloat16FromBits(buf.ReadUint16(ctxErr))))
+       }
+}
+
+func (s bfloat16ArraySerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) {
+       done := readArrayRefAndType(ctx, refMode, readType, value)
+       if done || ctx.HasError() {
+               return
+       }
+       s.ReadData(ctx, value)
+}
+
+func (s bfloat16ArraySerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) {
+       s.Read(ctx, refMode, false, false, value)
+}
diff --git a/go/fory/array_primitive_test.go b/go/fory/array_primitive_test.go
index c2e684af4..e8c99fb2f 100644
--- a/go/fory/array_primitive_test.go
+++ b/go/fory/array_primitive_test.go
@@ -20,6 +20,7 @@ package fory
 import (
        "testing"
 
+       "github.com/apache/fory/go/fory/bfloat16"
        "github.com/apache/fory/go/fory/float16"
        "github.com/stretchr/testify/assert"
        "github.com/stretchr/testify/require"
@@ -74,6 +75,21 @@ func TestPrimitiveArraySerializer(t *testing.T) {
                require.NoError(t, err)
                require.Equal(t, arr, result)
        })
+
+       t.Run("bfloat16_array", func(t *testing.T) {
+               arr := [3]bfloat16.BFloat16{
+                       bfloat16.BFloat16FromFloat32(1.0),
+                       bfloat16.BFloat16FromFloat32(2.5),
+                       bfloat16.BFloat16FromFloat32(-3.5),
+               }
+               data, err := f.Serialize(arr)
+               assert.NoError(t, err)
+
+               var result [3]bfloat16.BFloat16
+               err = f.Deserialize(data, &result)
+               assert.NoError(t, err)
+               assert.Equal(t, arr, result)
+       })
 }
 
 func TestArraySliceInteroperability(t *testing.T) {
diff --git a/go/fory/bfloat16/bfloat16.go b/go/fory/bfloat16/bfloat16.go
new file mode 100644
index 000000000..6b31ffdb9
--- /dev/null
+++ b/go/fory/bfloat16/bfloat16.go
@@ -0,0 +1,72 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package bfloat16
+
+import (
+       "fmt"
+       "math"
+)
+
+// BFloat16 represents a brain floating point number (bfloat16).
+// It is stored as a uint16 holding the upper 16 bits of the equivalent
+// float32 bit pattern (sign, 8-bit exponent, 7-bit mantissa).
+type BFloat16 uint16
+
+// BFloat16FromBits returns the BFloat16 corresponding to the given bit pattern.
+// The bits are used as-is; no validation or canonicalization is performed.
+func BFloat16FromBits(b uint16) BFloat16 {
+       return BFloat16(b)
+}
+
+// Bits returns the raw bit pattern of the floating point number.
+func (f BFloat16) Bits() uint16 {
+       return uint16(f)
+}
+
+// BFloat16FromFloat32 converts a float32 to a BFloat16.
+// Rounds to nearest, ties to even. Every NaN input (any sign or payload)
+// maps to the single quiet NaN 0x7FC0; ±0 and ±Inf keep their sign.
+func BFloat16FromFloat32(f float32) BFloat16 {
+       u := math.Float32bits(f)
+
+       // NaN check: all-ones exponent with a non-zero mantissa. This must come
+       // before rounding, because a NaN whose payload lives only in the low 16
+       // bits would otherwise truncate to an Inf bit pattern.
+       if (u&0x7F800000) == 0x7F800000 && (u&0x007FFFFF) != 0 {
+               return BFloat16(0x7FC0) // Canonical NaN
+       }
+
+       // Round to nearest, ties to even, on the 16 bits being dropped:
+       // adding 0x7FFF carries into bit 16 only when the dropped half is
+       // greater than 0x8000. Adding lsb (bit 16 of the input) as well makes an
+       // exact tie (0x8000) carry only when the kept half is odd, so ties land
+       // on an even result. A carry out of the mantissa bumps the exponent;
+       // from the largest finite magnitudes it reaches the all-ones exponent,
+       // i.e. the result overflows to ±Inf. Inf itself has zero low bits and
+       // is left unchanged.
+
+       lsb := (u >> 16) & 1
+       roundingBias := uint32(0x7FFF) + lsb
+       u += roundingBias
+       return BFloat16(u >> 16)
+}
+
+// Float32 returns the float32 representation of the BFloat16.
+// This widening is exact: the BFloat16 bits become the upper half of the
+// float32 and the lower 16 mantissa bits are zero.
+func (f BFloat16) Float32() float32 {
+       // Just shift left by 16 bits
+       return math.Float32frombits(uint32(f) << 16)
+}
+
+// String returns the string representation of f, formatted with %g from
+// its float32 value.
+func (f BFloat16) String() string {
+       return fmt.Sprintf("%g", f.Float32())
+}
diff --git a/go/fory/bfloat16/bfloat16_test.go 
b/go/fory/bfloat16/bfloat16_test.go
new file mode 100644
index 000000000..44745ae91
--- /dev/null
+++ b/go/fory/bfloat16/bfloat16_test.go
@@ -0,0 +1,96 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package bfloat16_test
+
+import (
+       "math"
+       "testing"
+
+       "github.com/apache/fory/go/fory/bfloat16"
+       "github.com/stretchr/testify/assert"
+)
+
+// TestBFloat16_Conversion checks the exact bit patterns BFloat16FromFloat32
+// produces for exactly representable and special values, and that converting
+// back with Float32 round-trips them.
+func TestBFloat16_Conversion(t *testing.T) {
+       tests := []struct {
+               name  string
+               f32   float32
+               want  uint16 // bits
+               check bool   // if true, check exact bits
+       }{
+               {"Zero", 0.0, 0x0000, true},
+               {"NegZero", float32(math.Copysign(0, -1)), 0x8000, true},
+               {"One", 1.0, 0x3F80, true},
+               {"MinusOne", -1.0, 0xBF80, true},
+               {"Inf", float32(math.Inf(1)), 0x7F80, true},
+               {"NegInf", float32(math.Inf(-1)), 0xFF80, true},
+               // 1.5 -> 0x3FC0. (0x3FC00000 is 1.5)
+               {"OnePointFive", 1.5, 0x3FC0, true},
+               // NaN inputs are canonicalized to the quiet NaN 0x7FC0. Without this
+               // case the IsNaN branch of the round-trip check below never runs.
+               {"NaN", float32(math.NaN()), 0x7FC0, true},
+       }
+
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       bf16 := bfloat16.BFloat16FromFloat32(tt.f32)
+                       if tt.check {
+                               assert.Equal(t, tt.want, bf16.Bits(), "Bits match")
+                       }
+
+                       // Round trip check. Inf is compared by class and sign, and NaN by
+                       // class only, since NaN never compares equal to itself.
+                       roundTrip := bf16.Float32()
+                       if math.IsInf(float64(tt.f32), 0) {
+                               assert.True(t, math.IsInf(float64(roundTrip), 0))
+                               assert.Equal(t, math.Signbit(float64(tt.f32)), math.Signbit(float64(roundTrip)))
+                       } else if math.IsNaN(float64(tt.f32)) {
+                               assert.True(t, math.IsNaN(float64(roundTrip)))
+                       } else {
+                               if tt.check {
+                                       assert.Equal(t, tt.f32, roundTrip, "Round trip value match")
+                               }
+                       }
+               })
+       }
+}
+
+func TestBFloat16_Rounding(t *testing.T) {
+       // BFloat16 has 7 bits of mantissa. For 1.0, ULP is 2^-7 and half ULP is 2^-8.
+       // Values are rounded to nearest, ties to even. 1.0 + 2^-8 is an exact tie
+       // and should round to 1.0 (even mantissa).
+
+       // The float32 representation of 1.0 is 0x3F800000.
+       // Adding 2^-8 (1/256) means setting bit 15 (23-8).
+       // So, 1.0 + 2^-8 in float32 is 0x3F808000.
+       val1 := math.Float32frombits(0x3F808000) // 1.0 + 2^-8
+       bf1 := bfloat16.BFloat16FromFloat32(val1)
+       assert.Equal(t, uint16(0x3F80), bf1.Bits(), "Round to even (down)")
+
+       // Setting bits 15 and 14 gives 1.0 + 2^-8 + 2^-9 = 1.0 + 3 * 2^-9 (0.75 ULP),
+       // whose float32 representation is 0x3F80C000. This is above the halfway
+       // point, so it rounds up to 0x3F81.
+       val2 := math.Float32frombits(0x3F80C000)
+       bf2 := bfloat16.BFloat16FromFloat32(val2)
+       assert.Equal(t, uint16(0x3F81), bf2.Bits(), "Round up")
+
+       // 1.0 + 2^-7 is the next representable number after 1.0. In float32 this
+       // is 0x3F810000, which converts without rounding.
+       val3 := math.Float32frombits(0x3F810000)
+       bf3 := bfloat16.BFloat16FromFloat32(val3)
+       assert.Equal(t, uint16(0x3F81), bf3.Bits(), "Exact")
+
+       // For 1.0 + 2^-7 + 2^-8 (0x3F818000), the LSB (bit 16) of the kept part
+       // 0x3F81 is 1 (odd), and the dropped half is exactly 0x8000 (a tie).
+       // Rounding to nearest even therefore rounds up.
+       // Result: 0x3F82.
+       val4 := math.Float32frombits(0x3F818000)
+       bf4 := bfloat16.BFloat16FromFloat32(val4)
+       assert.Equal(t, uint16(0x3F82), bf4.Bits(), "Round to even (up)")
+}
diff --git a/go/fory/primitive.go b/go/fory/primitive.go
index 2c316d549..998879042 100644
--- a/go/fory/primitive.go
+++ b/go/fory/primitive.go
@@ -663,3 +663,56 @@ func (s float16Serializer) Read(ctx *ReadContext, refMode 
RefMode, readType bool
 func (s float16Serializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, 
typeInfo *TypeInfo, value reflect.Value) {
        s.Read(ctx, refMode, false, false, value)
 }
+
+// ============================================================================
+// bfloat16Serializer - optimized bfloat16 serialization
+// ============================================================================
+
+// bfloat16Serializer handles the bfloat16.BFloat16 primitive, encoded on the
+// wire as its 16-bit pattern via WriteUint16/ReadUint16.
+type bfloat16Serializer struct{}
+
+var globalBFloat16Serializer = bfloat16Serializer{}
+
+func (s bfloat16Serializer) WriteData(ctx *WriteContext, value reflect.Value) {
+       // BFloat16 is a defined type (not an alias) whose underlying kind is
+       // uint16, so reflect's Uint() yields its raw bit pattern.
+       ctx.buffer.WriteUint16(uint16(value.Uint()))
+}
+
+func (s bfloat16Serializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, hasGenerics bool, value reflect.Value) {
+       if refMode != RefModeNone {
+               ctx.buffer.WriteInt8(NotNullValueFlag)
+       }
+       if writeType {
+               ctx.buffer.WriteUint8(uint8(BFLOAT16))
+       }
+       s.WriteData(ctx, value)
+}
+
+func (s bfloat16Serializer) ReadData(ctx *ReadContext, value reflect.Value) {
+       err := ctx.Err()
+       bits := ctx.buffer.ReadUint16(err)
+       if ctx.HasError() {
+               return
+       }
+       value.SetUint(uint64(bits))
+}
+
+// Read consumes the optional ref flag and the optional one-byte type id, then
+// the value. A type id other than BFLOAT16 is recorded as an error instead of
+// silently reinterpreting the following bytes, as bfloat16SliceSerializer.Read
+// also does for its type id.
+func (s bfloat16Serializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) {
+       err := ctx.Err()
+       if refMode != RefModeNone {
+               if ctx.buffer.ReadInt8(err) == NullFlag {
+                       return
+               }
+       }
+       if readType {
+               typeID := ctx.buffer.ReadUint8(err)
+               if !ctx.HasError() && typeID != uint8(BFLOAT16) {
+                       ctx.SetError(DeserializationErrorf("type mismatch: expected BFLOAT16 (%d), got %d", BFLOAT16, typeID))
+                       return
+               }
+       }
+       if ctx.HasError() {
+               return
+       }
+       s.ReadData(ctx, value)
+}
+
+func (s bfloat16Serializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) {
+       s.Read(ctx, refMode, false, false, value)
+}
diff --git a/go/fory/primitive_test.go b/go/fory/primitive_test.go
index 978a81d46..0f478590e 100644
--- a/go/fory/primitive_test.go
+++ b/go/fory/primitive_test.go
@@ -20,6 +20,7 @@ package fory
 import (
        "testing"
 
+       "github.com/apache/fory/go/fory/bfloat16"
        "github.com/apache/fory/go/fory/float16"
        "github.com/stretchr/testify/require"
 )
@@ -56,3 +57,36 @@ func TestFloat16PrimitiveSliceDirect(t *testing.T) {
        require.NoError(t, err)
        require.Equal(t, slice, resSlice)
 }
+
+// TestBFloat16Primitive round-trips a single BFloat16 value serialized as the
+// root object in xlang mode.
+func TestBFloat16Primitive(t *testing.T) {
+       f := New(WithXlang(true))
+       bf16 := bfloat16.BFloat16FromFloat32(3.14)
+
+       // Directly serialize a bfloat16 value
+       data, err := f.Serialize(bf16)
+       require.NoError(t, err)
+
+       var res bfloat16.BFloat16
+       err = f.Deserialize(data, &res)
+       require.NoError(t, err)
+
+       // The serializer carries the raw 16-bit pattern, so the bits must match exactly.
+       require.Equal(t, bf16.Bits(), res.Bits())
+
+       // Value check (approximate because BF16 precision is low)
+       require.InDelta(t, 3.14, res.Float32(), 0.1)
+}
+
+// TestBFloat16PrimitiveSliceDirect round-trips a []BFloat16 serialized as the
+// root object in xlang mode, including a zero element.
+func TestBFloat16PrimitiveSliceDirect(t *testing.T) {
+       // Tests serializing a slice as a root object
+       f := New(WithXlang(true))
+       bf16 := bfloat16.BFloat16FromFloat32(3.14)
+
+       slice := []bfloat16.BFloat16{bf16, bfloat16.BFloat16(0)}
+       data, err := f.Serialize(slice)
+       require.NoError(t, err)
+
+       var resSlice []bfloat16.BFloat16
+       err = f.Deserialize(data, &resSlice)
+       require.NoError(t, err)
+       require.Equal(t, slice, resSlice)
+}
diff --git a/go/fory/skip.go b/go/fory/skip.go
index b0b36c5af..34005ad74 100644
--- a/go/fory/skip.go
+++ b/go/fory/skip.go
@@ -576,6 +576,8 @@ func skipValue(ctx *ReadContext, fieldDef FieldDef, 
readRefFlag bool, isField bo
                _ = ctx.buffer.ReadVarint64(err)
 
        // Floating point types
+       case BFLOAT16, FLOAT16:
+               _ = ctx.buffer.ReadUint16(err)
        case FLOAT32:
                _ = ctx.buffer.ReadFloat32(err)
        case FLOAT64:
@@ -610,7 +612,7 @@ func skipValue(ctx *ReadContext, fieldDef FieldDef, 
readRefFlag bool, isField bo
                        return
                }
                _ = ctx.buffer.ReadBinary(length, err)
-       case INT16_ARRAY, UINT16_ARRAY, FLOAT16_ARRAY:
+       case INT16_ARRAY, UINT16_ARRAY, FLOAT16_ARRAY, BFLOAT16_ARRAY:
                length := ctx.buffer.ReadLength(err)
                if ctx.HasError() {
                        return
diff --git a/go/fory/slice_primitive.go b/go/fory/slice_primitive.go
index 200d144c6..e4daf990b 100644
--- a/go/fory/slice_primitive.go
+++ b/go/fory/slice_primitive.go
@@ -22,6 +22,7 @@ import (
        "strconv"
        "unsafe"
 
+       "github.com/apache/fory/go/fory/bfloat16"
        "github.com/apache/fory/go/fory/float16"
 )
 
@@ -1330,3 +1331,81 @@ func ReadStringSlice(buf *ByteBuffer, err *Error) 
[]string {
        }
        return result
 }
+
+// ============================================================================
+// bfloat16SliceSerializer - optimized []bfloat16.BFloat16 serialization
+// ============================================================================
+
+type bfloat16SliceSerializer struct{}
+
+// WriteData writes the slice as a length-prefixed block of 2 bytes per
+// element, writing the backing array's raw bytes directly on little-endian hosts.
+func (s bfloat16SliceSerializer) WriteData(ctx *WriteContext, value reflect.Value) {
+       // View the elements through the slice's data pointer rather than asserting
+       // value.Interface().([]bfloat16.BFloat16). That assertion panics for named
+       // slice types such as `type W []bfloat16.BFloat16`, and createSerializer
+       // picks this serializer by element type alone. A nil slice gives a nil view.
+       v := unsafe.Slice((*bfloat16.BFloat16)(value.UnsafePointer()), value.Len())
+       buf := ctx.Buffer()
+       length := len(v)
+       size := length * 2
+       buf.WriteLength(size)
+       if length == 0 {
+               return
+       }
+       if isLittleEndian {
+               buf.WriteBinary(unsafe.Slice((*byte)(unsafe.Pointer(&v[0])), size))
+               return
+       }
+       for i := 0; i < length; i++ {
+               buf.WriteUint16(v[i].Bits())
+       }
+}
+
+func (s bfloat16SliceSerializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, hasGenerics bool, value reflect.Value) {
+       done := writeSliceRefAndType(ctx, refMode, writeType, value, BFLOAT16_ARRAY)
+       if done || ctx.HasError() {
+               return
+       }
+       s.WriteData(ctx, value)
+}
+
+func (s bfloat16SliceSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) {
+       done, typeId := readSliceRefAndType(ctx, refMode, readType, value)
+       if done || ctx.HasError() {
+               return
+       }
+       if readType && typeId != uint32(BFLOAT16_ARRAY) {
+               ctx.SetError(DeserializationErrorf("slice type mismatch: expected BFLOAT16_ARRAY (%d), got %d", BFLOAT16_ARRAY, typeId))
+               return
+       }
+       s.ReadData(ctx, value)
+}
+
+func (s bfloat16SliceSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) {
+       s.Read(ctx, refMode, false, false, value)
+}
+
+// ReadData reads a length-prefixed block of 2-byte elements and stores a newly
+// allocated slice into value (an empty, non-nil slice when the block is empty).
+// An odd byte size is recorded on ctx as a deserialization error.
+func (s bfloat16SliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) {
+       buf := ctx.Buffer()
+       ctxErr := ctx.Err()
+       size := buf.ReadLength(ctxErr)
+       if ctx.HasError() {
+               return
+       }
+       // An odd size would make the unsafe copy below write one byte past the
+       // end of `result`, and could leave a stray byte unread in the buffer.
+       if size%2 != 0 {
+               ctx.SetError(DeserializationErrorf("bfloat16 slice byte size %d is not a multiple of 2", size))
+               return
+       }
+       length := size / 2
+
+       ptr := (*[]bfloat16.BFloat16)(value.Addr().UnsafePointer())
+       if length == 0 {
+               *ptr = make([]bfloat16.BFloat16, 0)
+               return
+       }
+
+       result := make([]bfloat16.BFloat16, length)
+       if isLittleEndian {
+               raw := buf.ReadBinary(size, ctxErr)
+               copy(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size), raw)
+       } else {
+               for i := 0; i < length; i++ {
+                       result[i] = bfloat16.BFloat16FromBits(buf.ReadUint16(ctxErr))
+               }
+       }
+       *ptr = result
+}
diff --git a/go/fory/slice_primitive_test.go b/go/fory/slice_primitive_test.go
index 82959d0da..e03a4cb00 100644
--- a/go/fory/slice_primitive_test.go
+++ b/go/fory/slice_primitive_test.go
@@ -21,6 +21,7 @@ import (
        "math"
        "testing"
 
+       "github.com/apache/fory/go/fory/bfloat16"
        "github.com/apache/fory/go/fory/float16"
        "github.com/stretchr/testify/assert"
 )
@@ -67,6 +68,48 @@ func TestFloat16Slice(t *testing.T) {
        })
 }
 
+// TestBFloat16Slice round-trips []BFloat16 values and checks that a non-nil
+// empty slice decodes as non-nil and empty, while a nil slice decodes as nil.
+func TestBFloat16Slice(t *testing.T) {
+       f := NewFory()
+
+       t.Run("bfloat16_slice", func(t *testing.T) {
+               slice := []bfloat16.BFloat16{
+                       bfloat16.BFloat16FromFloat32(1.0),
+                       bfloat16.BFloat16FromFloat32(2.5),
+                       bfloat16.BFloat16FromFloat32(-3.5),
+               }
+               data, err := f.Serialize(slice)
+               assert.NoError(t, err)
+
+               var result []bfloat16.BFloat16
+               err = f.Deserialize(data, &result)
+               assert.NoError(t, err)
+               assert.Equal(t, slice, result)
+       })
+
+       t.Run("bfloat16_slice_empty", func(t *testing.T) {
+               slice := []bfloat16.BFloat16{}
+               data, err := f.Serialize(slice)
+               assert.NoError(t, err)
+
+               var result []bfloat16.BFloat16
+               err = f.Deserialize(data, &result)
+               assert.NoError(t, err)
+               // Empty must stay distinguishable from nil after the round trip.
+               assert.NotNil(t, result)
+               assert.Empty(t, result)
+       })
+
+       t.Run("bfloat16_slice_nil", func(t *testing.T) {
+               var slice []bfloat16.BFloat16 = nil
+               data, err := f.Serialize(slice)
+               assert.NoError(t, err)
+
+               var result []bfloat16.BFloat16
+               err = f.Deserialize(data, &result)
+               assert.NoError(t, err)
+               assert.Nil(t, result)
+       })
+}
+
 func TestIntSlice(t *testing.T) {
        f := NewFory()
 
diff --git a/go/fory/type_resolver.go b/go/fory/type_resolver.go
index b10a134d8..19cd83997 100644
--- a/go/fory/type_resolver.go
+++ b/go/fory/type_resolver.go
@@ -29,6 +29,7 @@ import (
        "time"
        "unsafe"
 
+       "github.com/apache/fory/go/fory/bfloat16"
        "github.com/apache/fory/go/fory/float16"
        "github.com/apache/fory/go/fory/meta"
 )
@@ -77,6 +78,7 @@ var (
        float32SliceType     = reflect.TypeOf((*[]float32)(nil)).Elem()
        float64SliceType     = reflect.TypeOf((*[]float64)(nil)).Elem()
        float16SliceType     = reflect.TypeOf((*[]float16.Float16)(nil)).Elem()
+       bfloat16SliceType    = 
reflect.TypeOf((*[]bfloat16.BFloat16)(nil)).Elem()
        interfaceSliceType   = reflect.TypeOf((*[]any)(nil)).Elem()
        interfaceMapType     = reflect.TypeOf((*map[any]any)(nil)).Elem()
        stringStringMapType  = reflect.TypeOf((*map[string]string)(nil)).Elem()
@@ -103,6 +105,7 @@ var (
        float32Type          = reflect.TypeOf((*float32)(nil)).Elem()
        float64Type          = reflect.TypeOf((*float64)(nil)).Elem()
        float16Type          = reflect.TypeOf((*float16.Float16)(nil)).Elem()
+       bfloat16Type         = reflect.TypeOf((*bfloat16.BFloat16)(nil)).Elem()
        dateType             = reflect.TypeOf((*Date)(nil)).Elem()
        timestampType        = reflect.TypeOf((*time.Time)(nil)).Elem()
        genericSetType       = reflect.TypeOf((*Set[any])(nil)).Elem()
@@ -259,6 +262,7 @@ func newTypeResolver(fory *Fory) *TypeResolver {
                float32Type,
                float64Type,
                float16Type,
+               bfloat16Type,
                stringType,
                dateType,
                timestampType,
@@ -416,6 +420,7 @@ func (r *TypeResolver) initialize() {
                {float32SliceType, FLOAT32_ARRAY, float32SliceSerializer{}},
                {float64SliceType, FLOAT64_ARRAY, float64SliceSerializer{}},
                {float16SliceType, FLOAT16_ARRAY, float16SliceSerializer{}},
+               {bfloat16SliceType, BFLOAT16_ARRAY, bfloat16SliceSerializer{}},
                // Register common map types for fast path with optimized 
serializers
                {stringStringMapType, MAP, stringStringMapSerializer{}},
                {stringInt64MapType, MAP, stringInt64MapSerializer{}},
@@ -440,6 +445,7 @@ func (r *TypeResolver) initialize() {
                {float32Type, FLOAT32, float32Serializer{}},
                {float64Type, FLOAT64, float64Serializer{}},
                {float16Type, FLOAT16, float16Serializer{}},
+               {bfloat16Type, BFLOAT16, bfloat16Serializer{}},
                {dateType, DATE, dateSerializer{}},
                {timestampType, TIMESTAMP, timeSerializer{}},
                {genericSetType, SET, setSerializer{}},
@@ -1712,6 +1718,10 @@ func (r *TypeResolver) createSerializer(type_ 
reflect.Type, mapInStruct bool) (s
                        if elem == float16Type {
                                return float16SliceSerializer{}, nil
                        }
+                       // Check for fory.BFloat16 (aliased to uint16)
+                       if elem == bfloat16Type {
+                               return bfloat16SliceSerializer{}, nil
+                       }
                        return uint16SliceSerializer{}, nil
                case reflect.Uint32:
                        return uint32SliceSerializer{}, nil
@@ -1763,6 +1773,10 @@ func (r *TypeResolver) createSerializer(type_ 
reflect.Type, mapInStruct bool) (s
                        if elem == float16Type {
                                return float16ArraySerializer{arrayType: 
type_}, nil
                        }
+                       // Check for fory.BFloat16 (aliased to uint16)
+                       if elem == bfloat16Type {
+                               return bfloat16ArraySerializer{arrayType: 
type_}, nil
+                       }
                        return uint16ArraySerializer{arrayType: type_}, nil
                case reflect.Uint32:
                        return uint32ArraySerializer{arrayType: type_}, nil
diff --git a/go/fory/type_test.go b/go/fory/type_test.go
index 5b9766930..e8978e593 100644
--- a/go/fory/type_test.go
+++ b/go/fory/type_test.go
@@ -21,6 +21,7 @@ import (
        "reflect"
        "testing"
 
+       "github.com/apache/fory/go/fory/bfloat16"
        "github.com/apache/fory/go/fory/float16"
        "github.com/stretchr/testify/require"
 )
@@ -41,6 +42,9 @@ func TestTypeResolver(t *testing.T) {
                {reflect.TypeOf((*int)(nil)), "*int"},
                {reflect.TypeOf((*[10]int)(nil)), "*[10]int"},
                {reflect.TypeOf((*[10]int)(nil)).Elem(), "[10]int"},
+               {reflect.TypeOf((*bfloat16.BFloat16)(nil)).Elem(), 
"bfloat16.BFloat16"},
+               {reflect.TypeOf((*[]bfloat16.BFloat16)(nil)).Elem(), 
"[]bfloat16.BFloat16"},
+               {reflect.TypeOf((*[10]bfloat16.BFloat16)(nil)).Elem(), 
"[10]bfloat16.BFloat16"},
                {reflect.TypeOf((*[]map[string][]map[string]*any)(nil)).Elem(),
                        "[]map[string][]map[string]*interface {}"},
                {reflect.TypeOf((*A)(nil)), "*@example.A"},
@@ -84,6 +88,7 @@ func TestCreateSerializerSliceTypes(t *testing.T) {
                {reflect.TypeOf([]uint{}), 
reflect.TypeOf(uintSliceSerializer{})},
                {reflect.TypeOf([]uint16{}), 
reflect.TypeOf(uint16SliceSerializer{})},
                {reflect.TypeOf([]float16.Float16{}), 
reflect.TypeOf(float16SliceSerializer{})},
+               {reflect.TypeOf([]bfloat16.BFloat16{}), 
reflect.TypeOf(bfloat16SliceSerializer{})},
                {reflect.TypeOf([]uint32{}), 
reflect.TypeOf(uint32SliceSerializer{})},
                {reflect.TypeOf([]uint64{}), 
reflect.TypeOf(uint64SliceSerializer{})},
                {reflect.TypeOf([]string{}), 
reflect.TypeOf(stringSliceSerializer{})},
@@ -135,6 +140,7 @@ func TestCreateSerializerArrayTypes(t *testing.T) {
                {reflect.TypeOf([4]byte{}), 
reflect.TypeOf(uint8ArraySerializer{})},
                {reflect.TypeOf([4]uint16{}), 
reflect.TypeOf(uint16ArraySerializer{})},
                {reflect.TypeOf([4]float16.Float16{}), 
reflect.TypeOf(float16ArraySerializer{})},
+               {reflect.TypeOf([4]bfloat16.BFloat16{}), 
reflect.TypeOf(bfloat16ArraySerializer{})},
                {reflect.TypeOf([4]uint32{}), 
reflect.TypeOf(uint32ArraySerializer{})},
                {reflect.TypeOf([4]uint64{}), 
reflect.TypeOf(uint64ArraySerializer{})},
        }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to