This is an automated email from the ASF dual-hosted git repository.
chaokunyang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/fory.git
The following commit(s) were added to refs/heads/main by this push:
new 1331381ae feat(go): add support for bfloat16 (#3310)
1331381ae is described below
commit 1331381ae968d792a8c7ab54736bb8c5a0084100
Author: Chang-Yen (Brian) Li <[email protected]>
AuthorDate: Wed Feb 11 15:39:28 2026 -0600
feat(go): add support for bfloat16 (#3310)
## Why?
We want to use `bfloat16` (BF16) in FDL to reduce payload size while
keeping a wide exponent range (common in ML/AI workflows). Fory
currently lacks a BF16 primitive and optimized arrays. This PR adds full
support for `bfloat16` in the Fory Go runtime and codegen.
## What does this PR do?
1. **Compiler**:
* Adds `bfloat16` to the IR type system (`PrimitiveKind`).
* Updates the Go generator to map `bfloat16` to `bfloat16.BFloat16` and
handle imports.
* Adds codegen tests for `bfloat16`.
2. **Go Runtime**:
* Adds a new package `go/fory/bfloat16` with a strong type `type
BFloat16 uint16`.
* Implements conversions between `float32` and `bfloat16` using IEEE 754
round-to-nearest, ties-to-even rounding.
* Handles special values: ±Inf and ±0 are preserved, and every NaN is
canonicalized to the quiet NaN `0x7FC0`.
3. **Serialization**:
* Implements optimized serializers for `bfloat16`, `[]bfloat16`, and
`[N]bfloat16`, mirroring the `float16` implementation structure.
4. **Type System**:
* Updates `TypeResolver` to register `BFloat16`, `[]BFloat16`, and
`[N]BFloat16` types.
* Adds type resolution logic to distinguish `BFloat16` from `uint16`
(since it is a defined type whose underlying kind is `uint16`, not a type
alias) for correct serializer selection.
5. **Tests**: Adds comprehensive unit tests for conversions, rounding
logic, serialization, and codegen.
## Related issues
Fixes #3284
## Does this PR introduce any user-facing change?
- [x] Does this PR introduce any public API change?
- [x] Does this PR introduce any binary protocol compatibility change?
## Benchmark
---
compiler/fory_compiler/generators/go.py | 3 +
compiler/fory_compiler/ir/types.py | 2 +
.../fory_compiler/tests/test_generated_code.py | 30 +++++++
go/fory/array_primitive.go | 75 +++++++++++++++++
go/fory/array_primitive_test.go | 16 ++++
go/fory/bfloat16/bfloat16.go | 72 ++++++++++++++++
go/fory/bfloat16/bfloat16_test.go | 96 ++++++++++++++++++++++
go/fory/primitive.go | 53 ++++++++++++
go/fory/primitive_test.go | 34 ++++++++
go/fory/skip.go | 4 +-
go/fory/slice_primitive.go | 79 ++++++++++++++++++
go/fory/slice_primitive_test.go | 43 ++++++++++
go/fory/type_resolver.go | 14 ++++
go/fory/type_test.go | 6 ++
14 files changed, 526 insertions(+), 1 deletion(-)
diff --git a/compiler/fory_compiler/generators/go.py
b/compiler/fory_compiler/generators/go.py
index d7784a0df..0fd9d4b0c 100644
--- a/compiler/fory_compiler/generators/go.py
+++ b/compiler/fory_compiler/generators/go.py
@@ -190,6 +190,7 @@ class GoGenerator(BaseGenerator):
PrimitiveKind.VAR_UINT64: "uint64",
PrimitiveKind.TAGGED_UINT64: "uint64",
PrimitiveKind.FLOAT16: "float16.Float16",
+ PrimitiveKind.BFLOAT16: "bfloat16.BFloat16",
PrimitiveKind.FLOAT32: "float32",
PrimitiveKind.FLOAT64: "float64",
PrimitiveKind.STRING: "string",
@@ -1090,6 +1091,8 @@ class GoGenerator(BaseGenerator):
imports.add('"time"')
elif field_type.kind == PrimitiveKind.FLOAT16:
imports.add('float16 "github.com/apache/fory/go/fory/float16"')
+ elif field_type.kind == PrimitiveKind.BFLOAT16:
+ imports.add('bfloat16
"github.com/apache/fory/go/fory/bfloat16"')
elif isinstance(field_type, ListType):
self.collect_imports(field_type.element_type, imports)
diff --git a/compiler/fory_compiler/ir/types.py
b/compiler/fory_compiler/ir/types.py
index 3dfc3d8ed..facc95ef6 100644
--- a/compiler/fory_compiler/ir/types.py
+++ b/compiler/fory_compiler/ir/types.py
@@ -39,6 +39,7 @@ class PrimitiveKind(PyEnum):
VAR_UINT64 = "var_uint64"
TAGGED_UINT64 = "tagged_uint64"
FLOAT16 = "float16"
+ BFLOAT16 = "bfloat16"
FLOAT32 = "float32"
FLOAT64 = "float64"
STRING = "string"
@@ -67,6 +68,7 @@ PRIMITIVE_TYPES = {
"fixed_uint64": PrimitiveKind.UINT64,
"tagged_uint64": PrimitiveKind.TAGGED_UINT64,
"float16": PrimitiveKind.FLOAT16,
+ "bfloat16": PrimitiveKind.BFLOAT16,
"float32": PrimitiveKind.FLOAT32,
"float64": PrimitiveKind.FLOAT64,
"string": PrimitiveKind.STRING,
diff --git a/compiler/fory_compiler/tests/test_generated_code.py
b/compiler/fory_compiler/tests/test_generated_code.py
index 7fda28f48..13dc99eab 100644
--- a/compiler/fory_compiler/tests/test_generated_code.py
+++ b/compiler/fory_compiler/tests/test_generated_code.py
@@ -497,3 +497,33 @@ def test_generated_code_tree_ref_options_equivalent():
cpp_output = render_files(generate_files(schemas["fdl"], CppGenerator))
assert "SharedWeak<TreeNode>" in cpp_output
+
+
+def test_go_bfloat16_generation():
+ idl = dedent(
+ """
+ package bfloat16_test;
+
+ message BFloat16Message {
+ bfloat16 val = 1;
+ optional bfloat16 opt_val = 2;
+ list<bfloat16> list_val = 3;
+ }
+ """
+ )
+ schema = parse_fdl(idl)
+ files = generate_files(schema, GoGenerator)
+
+ assert len(files) == 1
+ content = list(files.values())[0]
+
+ # Check imports
+ assert 'bfloat16 "github.com/apache/fory/go/fory/bfloat16"' in content
+
+ # Check fields
+ assert '\tVal bfloat16.BFloat16 `fory:"id=1"`' in content
+ assert (
+ '\tOptVal optional.Optional[bfloat16.BFloat16] `fory:"id=2,nullable"`'
+ in content
+ )
+ assert "\tListVal []bfloat16.BFloat16" in content
diff --git a/go/fory/array_primitive.go b/go/fory/array_primitive.go
index 977773359..27813060b 100644
--- a/go/fory/array_primitive.go
+++ b/go/fory/array_primitive.go
@@ -21,6 +21,7 @@ import (
"reflect"
"unsafe"
+ "github.com/apache/fory/go/fory/bfloat16"
"github.com/apache/fory/go/fory/float16"
)
@@ -871,3 +872,77 @@ func (s float16ArraySerializer) Read(ctx *ReadContext,
refMode RefMode, readType
func (s float16ArraySerializer) ReadWithTypeInfo(ctx *ReadContext, refMode
RefMode, typeInfo *TypeInfo, value reflect.Value) {
s.Read(ctx, refMode, false, false, value)
}
+
+// ============================================================================
+// bfloat16ArraySerializer - optimized [N]bfloat16.BFloat16 serialization
+// ============================================================================
+
+// bfloat16ArraySerializer serializes fixed-size arrays of bfloat16.BFloat16
+// as BFLOAT16_ARRAY: a byte-length prefix (2 bytes per element) followed by
+// each element's raw 16-bit pattern.
+type bfloat16ArraySerializer struct {
+	arrayType reflect.Type
+}
+
+// WriteData writes the payload byte length and then the element bits.
+// Addressable arrays on little-endian hosts are copied in bulk. Otherwise the
+// elements are written one at a time. Elements are read with Uint() instead of
+// Interface() so arrays reached through unexported struct fields do not panic
+// and no per-element boxing occurs.
+func (s bfloat16ArraySerializer) WriteData(ctx *WriteContext, value reflect.Value) {
+	buf := ctx.Buffer()
+	length := value.Len()
+	size := length * 2
+	buf.WriteLength(size)
+	if length == 0 {
+		return
+	}
+	if value.CanAddr() && isLittleEndian {
+		ptr := value.Addr().UnsafePointer()
+		buf.WriteBinary(unsafe.Slice((*byte)(ptr), size))
+		return
+	}
+	for i := 0; i < length; i++ {
+		buf.WriteUint16(uint16(value.Index(i).Uint()))
+	}
+}
+
+// Write emits the ref/type header via writeArrayRefAndType (type id
+// BFLOAT16_ARRAY) and then the array payload.
+func (s bfloat16ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, hasGenerics bool, value reflect.Value) {
+	writeArrayRefAndType(ctx, refMode, writeType, value, BFLOAT16_ARRAY)
+	if ctx.HasError() {
+		return
+	}
+	s.WriteData(ctx, value)
+}
+
+// ReadData reads a BFLOAT16_ARRAY payload into the fixed-size array value.
+// The declared byte size must be non-negative and even, and it must describe
+// exactly value.Len() elements. Otherwise a deserialization error is recorded.
+// Odd sizes are rejected explicitly: size/2 would truncate, and the bulk copy
+// below would then write one byte past the end of the array.
+func (s bfloat16ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) {
+	buf := ctx.Buffer()
+	ctxErr := ctx.Err()
+	size := buf.ReadLength(ctxErr)
+	if ctx.HasError() {
+		return
+	}
+	if size < 0 || size%2 != 0 {
+		ctx.SetError(DeserializationErrorf("invalid bfloat16 array byte size %d", size))
+		return
+	}
+	length := size / 2
+	if length != value.Type().Len() {
+		ctx.SetError(DeserializationErrorf("array length %d does not match type %v", length, value.Type()))
+		return
+	}
+	if length == 0 {
+		return
+	}
+	if isLittleEndian {
+		raw := buf.ReadBinary(size, ctxErr)
+		if ctx.HasError() {
+			return
+		}
+		ptr := value.Addr().UnsafePointer()
+		copy(unsafe.Slice((*byte)(ptr), size), raw)
+		return
+	}
+	for i := 0; i < length; i++ {
+		value.Index(i).SetUint(uint64(buf.ReadUint16(ctxErr)))
+	}
+}
+
+// Read consumes the ref/type header via readArrayRefAndType. If that helper
+// reports the value as already handled, or an error occurred, Read stops.
+// Otherwise it reads the payload.
+func (s bfloat16ArraySerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) {
+	done := readArrayRefAndType(ctx, refMode, readType, value)
+	if done || ctx.HasError() {
+		return
+	}
+	s.ReadData(ctx, value)
+}
+
+// ReadWithTypeInfo reads the value when type info has already been consumed.
+// typeInfo itself is not inspected.
+func (s bfloat16ArraySerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) {
+	s.Read(ctx, refMode, false, false, value)
+}
diff --git a/go/fory/array_primitive_test.go b/go/fory/array_primitive_test.go
index c2e684af4..e8c99fb2f 100644
--- a/go/fory/array_primitive_test.go
+++ b/go/fory/array_primitive_test.go
@@ -20,6 +20,7 @@ package fory
import (
"testing"
+ "github.com/apache/fory/go/fory/bfloat16"
"github.com/apache/fory/go/fory/float16"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -74,6 +75,21 @@ func TestPrimitiveArraySerializer(t *testing.T) {
require.NoError(t, err)
require.Equal(t, arr, result)
})
+
+ t.Run("bfloat16_array", func(t *testing.T) {
+ arr := [3]bfloat16.BFloat16{
+ bfloat16.BFloat16FromFloat32(1.0),
+ bfloat16.BFloat16FromFloat32(2.5),
+ bfloat16.BFloat16FromFloat32(-3.5),
+ }
+ data, err := f.Serialize(arr)
+ assert.NoError(t, err)
+
+ var result [3]bfloat16.BFloat16
+ err = f.Deserialize(data, &result)
+ assert.NoError(t, err)
+ assert.Equal(t, arr, result)
+ })
}
func TestArraySliceInteroperability(t *testing.T) {
diff --git a/go/fory/bfloat16/bfloat16.go b/go/fory/bfloat16/bfloat16.go
new file mode 100644
index 000000000..6b31ffdb9
--- /dev/null
+++ b/go/fory/bfloat16/bfloat16.go
@@ -0,0 +1,72 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package bfloat16
+
+import (
+ "fmt"
+ "math"
+)
+
+// BFloat16 is a brain floating-point (bfloat16) value, stored as its raw
+// 16-bit pattern. The pattern is the upper half of an IEEE 754 binary32:
+// 1 sign bit, 8 exponent bits and 7 mantissa bits.
+type BFloat16 uint16
+
+// BFloat16FromBits reinterprets the raw bit pattern b as a BFloat16.
+func BFloat16FromBits(b uint16) BFloat16 { return BFloat16(b) }
+
+// Bits returns the raw 16-bit pattern of f.
+func (f BFloat16) Bits() uint16 { return uint16(f) }
+
+// BFloat16FromFloat32 converts f to the nearest BFloat16, breaking ties toward
+// the candidate with an even mantissa.
+//
+// Every NaN input, of either sign and with any payload, becomes the canonical
+// quiet NaN 0x7FC0. Infinities and signed zeros are preserved. Finite values
+// beyond the BFloat16 range round to the infinity of the same sign.
+func BFloat16FromFloat32(f float32) BFloat16 {
+	bits := math.Float32bits(f)
+
+	// With the sign masked off, anything above +Inf's pattern has an
+	// all-ones exponent and a non-zero mantissa, i.e. it is a NaN.
+	if bits&0x7FFFFFFF > 0x7F800000 {
+		return BFloat16(0x7FC0)
+	}
+
+	// Adding 0x7FFF, plus one more when the kept LSB (bit 16) is odd, then
+	// truncating the low 16 bits rounds to nearest, ties to even. A carry out
+	// of the mantissa bumps the exponent, which can reach infinity.
+	keptLSB := (bits >> 16) & 1
+	return BFloat16((bits + 0x7FFF + keptLSB) >> 16)
+}
+
+// Float32 widens f to float32 exactly by placing its bits in the upper half.
+func (f BFloat16) Float32() float32 {
+	return math.Float32frombits(uint32(f.Bits()) << 16)
+}
+
+// String formats f with the %g verb applied to its float32 value.
+func (f BFloat16) String() string {
+	v := f.Float32()
+	return fmt.Sprintf("%g", v)
+}
diff --git a/go/fory/bfloat16/bfloat16_test.go
b/go/fory/bfloat16/bfloat16_test.go
new file mode 100644
index 000000000..44745ae91
--- /dev/null
+++ b/go/fory/bfloat16/bfloat16_test.go
@@ -0,0 +1,96 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package bfloat16_test
+
+import (
+ "math"
+ "testing"
+
+ "github.com/apache/fory/go/fory/bfloat16"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestBFloat16_Conversion(t *testing.T) {
+ tests := []struct {
+ name string
+ f32 float32
+ want uint16 // bits
+ check bool // if true, check exact bits
+ }{
+ {"Zero", 0.0, 0x0000, true},
+ {"NegZero", float32(math.Copysign(0, -1)), 0x8000, true},
+ {"One", 1.0, 0x3F80, true},
+ {"MinusOne", -1.0, 0xBF80, true},
+ {"Inf", float32(math.Inf(1)), 0x7F80, true},
+ {"NegInf", float32(math.Inf(-1)), 0xFF80, true},
+ // 1.5 -> 0x3FC0. (0x3FC00000 is 1.5)
+ {"OnePointFive", 1.5, 0x3FC0, true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ bf16 := bfloat16.BFloat16FromFloat32(tt.f32)
+ if tt.check {
+ assert.Equal(t, tt.want, bf16.Bits(), "Bits
match")
+ }
+
+ // Round trip check
+ roundTrip := bf16.Float32()
+ if math.IsInf(float64(tt.f32), 0) {
+ assert.True(t, math.IsInf(float64(roundTrip),
0))
+ assert.Equal(t, math.Signbit(float64(tt.f32)),
math.Signbit(float64(roundTrip)))
+ } else if math.IsNaN(float64(tt.f32)) {
+ assert.True(t, math.IsNaN(float64(roundTrip)))
+ } else {
+ if tt.check {
+ assert.Equal(t, tt.f32, roundTrip,
"Round trip value match")
+ }
+ }
+ })
+ }
+}
+
+func TestBFloat16_Rounding(t *testing.T) {
+ // BFloat16 has 7 bits of mantissa. For 1.0, ULP is 2^-7, and half ULP
is 2^-8.
+ // Values are rounded to nearest even. 1.0 + 2^-8 should round to 1.0
(even mantissa).
+
+ // The float32 representation of 1.0 is 0x3F800000.
+ // Adding 2^-8 (1/256) means setting bit 15 (23-8).
+ // So, 1.0 + 2^-8 in float32 is 0x3F808000.
+ val1 := math.Float32frombits(0x3F808000) // 1.0 + 2^-8
+ bf1 := bfloat16.BFloat16FromFloat32(val1)
+ assert.Equal(t, uint16(0x3F80), bf1.Bits(), "Round to even (down)")
+
+ // For 1.0 + 3 * 2^-8 (1.5 ULP), bits 15 and 14 are set,
+ // making the float32 representation 0x3F80C000. This rounds up.
+ val2 := math.Float32frombits(0x3F80C000)
+ bf2 := bfloat16.BFloat16FromFloat32(val2)
+ assert.Equal(t, uint16(0x3F81), bf2.Bits(), "Round up")
+
+ // 1.0 + 2^-7 is the next representable number after 1.0. In float32,
this is 0x3F810000.
+ val3 := math.Float32frombits(0x3F810000)
+ bf3 := bfloat16.BFloat16FromFloat32(val3)
+ assert.Equal(t, uint16(0x3F81), bf3.Bits(), "Exact")
+
+ // For 1.0 + 2^-7 + 2^-8 (0x3F818000), the LSB (bit 16) of 0x3F81 is 1
(odd),
+ // and the guard bit (bit 15) is 1. Rounding to nearest even means
rounding up.
+ // Result: 0x3F82.
+ val4 := math.Float32frombits(0x3F818000)
+ bf4 := bfloat16.BFloat16FromFloat32(val4)
+ assert.Equal(t, uint16(0x3F82), bf4.Bits(), "Round to even (up)")
+}
diff --git a/go/fory/primitive.go b/go/fory/primitive.go
index 2c316d549..998879042 100644
--- a/go/fory/primitive.go
+++ b/go/fory/primitive.go
@@ -663,3 +663,56 @@ func (s float16Serializer) Read(ctx *ReadContext, refMode
RefMode, readType bool
func (s float16Serializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode,
typeInfo *TypeInfo, value reflect.Value) {
s.Read(ctx, refMode, false, false, value)
}
+
+// ============================================================================
+// bfloat16Serializer - optimized bfloat16 serialization
+// ============================================================================
+
+// bfloat16Serializer reads and writes a single bfloat16.BFloat16 as its raw
+// 16-bit pattern (type id BFLOAT16). It works on reflect values of kind
+// uint16, BFloat16's underlying type.
+type bfloat16Serializer struct{}
+
+// globalBFloat16Serializer is a shared instance; the serializer holds no state.
+var globalBFloat16Serializer bfloat16Serializer
+
+// WriteData writes the value's bit pattern as a uint16.
+func (s bfloat16Serializer) WriteData(ctx *WriteContext, value reflect.Value) {
+	bits := uint16(value.Uint())
+	ctx.buffer.WriteUint16(bits)
+}
+
+// Write emits, in order:
+//   - NotNullValueFlag, unless refMode is RefModeNone;
+//   - the BFLOAT16 type id, when writeType is set;
+//   - the value bits.
+//
+// hasGenerics is unused.
+func (s bfloat16Serializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, hasGenerics bool, value reflect.Value) {
+	if refMode != RefModeNone {
+		ctx.buffer.WriteInt8(NotNullValueFlag)
+	}
+	if writeType {
+		ctx.buffer.WriteUint8(uint8(BFLOAT16))
+	}
+	s.WriteData(ctx, value)
+}
+
+// ReadData reads 16 bits and stores them into value unless the read failed.
+func (s bfloat16Serializer) ReadData(ctx *ReadContext, value reflect.Value) {
+	bits := ctx.buffer.ReadUint16(ctx.Err())
+	if !ctx.HasError() {
+		value.SetUint(uint64(bits))
+	}
+}
+
+// Read mirrors Write:
+//   - it consumes the ref flag when refMode requires one, returning early
+//     (value untouched) on NullFlag;
+//   - it consumes the type id byte when readType is set (the id is not
+//     validated here);
+//   - it then reads the bits.
+func (s bfloat16Serializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) {
+	errPtr := ctx.Err()
+	if refMode != RefModeNone && ctx.buffer.ReadInt8(errPtr) == NullFlag {
+		return
+	}
+	if readType {
+		_ = ctx.buffer.ReadUint8(errPtr)
+	}
+	if !ctx.HasError() {
+		s.ReadData(ctx, value)
+	}
+}
+
+// ReadWithTypeInfo reads the value when type info has already been consumed.
+// typeInfo itself is not inspected.
+func (s bfloat16Serializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) {
+	s.Read(ctx, refMode, false, false, value)
+}
diff --git a/go/fory/primitive_test.go b/go/fory/primitive_test.go
index 978a81d46..0f478590e 100644
--- a/go/fory/primitive_test.go
+++ b/go/fory/primitive_test.go
@@ -20,6 +20,7 @@ package fory
import (
"testing"
+ "github.com/apache/fory/go/fory/bfloat16"
"github.com/apache/fory/go/fory/float16"
"github.com/stretchr/testify/require"
)
@@ -56,3 +57,36 @@ func TestFloat16PrimitiveSliceDirect(t *testing.T) {
require.NoError(t, err)
require.Equal(t, slice, resSlice)
}
+
+// TestBFloat16Primitive round-trips one bfloat16 value through xlang
+// serialization and checks that its bit pattern survives unchanged.
+func TestBFloat16Primitive(t *testing.T) {
+	f := New(WithXlang(true))
+	want := bfloat16.BFloat16FromFloat32(3.14)
+
+	payload, err := f.Serialize(want)
+	require.NoError(t, err)
+
+	var got bfloat16.BFloat16
+	require.NoError(t, f.Deserialize(payload, &got))
+
+	require.Equal(t, want.Bits(), got.Bits())
+	// BF16 keeps only 8 significant bits, so compare the decoded value loosely.
+	require.InDelta(t, 3.14, got.Float32(), 0.1)
+}
+
+// TestBFloat16PrimitiveSliceDirect round-trips a []bfloat16.BFloat16 passed as
+// the root object.
+func TestBFloat16PrimitiveSliceDirect(t *testing.T) {
+	f := New(WithXlang(true))
+	want := []bfloat16.BFloat16{bfloat16.BFloat16FromFloat32(3.14), bfloat16.BFloat16(0)}
+
+	payload, err := f.Serialize(want)
+	require.NoError(t, err)
+
+	var got []bfloat16.BFloat16
+	require.NoError(t, f.Deserialize(payload, &got))
+	require.Equal(t, want, got)
+}
diff --git a/go/fory/skip.go b/go/fory/skip.go
index b0b36c5af..34005ad74 100644
--- a/go/fory/skip.go
+++ b/go/fory/skip.go
@@ -576,6 +576,8 @@ func skipValue(ctx *ReadContext, fieldDef FieldDef,
readRefFlag bool, isField bo
_ = ctx.buffer.ReadVarint64(err)
// Floating point types
+ case BFLOAT16, FLOAT16:
+ _ = ctx.buffer.ReadUint16(err)
case FLOAT32:
_ = ctx.buffer.ReadFloat32(err)
case FLOAT64:
@@ -610,7 +612,7 @@ func skipValue(ctx *ReadContext, fieldDef FieldDef,
readRefFlag bool, isField bo
return
}
_ = ctx.buffer.ReadBinary(length, err)
- case INT16_ARRAY, UINT16_ARRAY, FLOAT16_ARRAY:
+ case INT16_ARRAY, UINT16_ARRAY, FLOAT16_ARRAY, BFLOAT16_ARRAY:
length := ctx.buffer.ReadLength(err)
if ctx.HasError() {
return
diff --git a/go/fory/slice_primitive.go b/go/fory/slice_primitive.go
index 200d144c6..e4daf990b 100644
--- a/go/fory/slice_primitive.go
+++ b/go/fory/slice_primitive.go
@@ -22,6 +22,7 @@ import (
"strconv"
"unsafe"
+ "github.com/apache/fory/go/fory/bfloat16"
"github.com/apache/fory/go/fory/float16"
)
@@ -1330,3 +1331,81 @@ func ReadStringSlice(buf *ByteBuffer, err *Error)
[]string {
}
return result
}
+
+// ============================================================================
+// bfloat16SliceSerializer - optimized []bfloat16.BFloat16 serialization
+// ============================================================================
+
+// bfloat16SliceSerializer serializes slices whose element type is
+// bfloat16.BFloat16 as BFLOAT16_ARRAY: a byte-length prefix (2 bytes per
+// element) followed by each element's raw 16-bit pattern.
+type bfloat16SliceSerializer struct{}
+
+// WriteData writes the payload byte length and then the element bits.
+// On little-endian hosts the slice's backing memory is copied in bulk.
+// Otherwise each element is written with WriteUint16.
+//
+// The slice is accessed through reflect (Len, UnsafePointer, Index().Uint())
+// rather than a type assertion to []bfloat16.BFloat16, which would panic for
+// named slice types such as `type Vec []bfloat16.BFloat16`.
+// NOTE(review): createSerializer appears to route any slice with a BFloat16
+// element here.
+func (s bfloat16SliceSerializer) WriteData(ctx *WriteContext, value reflect.Value) {
+	buf := ctx.Buffer()
+	length := value.Len()
+	size := length * 2
+	buf.WriteLength(size)
+	if length == 0 {
+		return
+	}
+	if isLittleEndian {
+		// For a slice, UnsafePointer returns a pointer to the first element.
+		buf.WriteBinary(unsafe.Slice((*byte)(value.UnsafePointer()), size))
+		return
+	}
+	for i := 0; i < length; i++ {
+		buf.WriteUint16(uint16(value.Index(i).Uint()))
+	}
+}
+
+// Write emits the ref/type header via writeSliceRefAndType (type id
+// BFLOAT16_ARRAY). If that helper reports the value as fully written, or an
+// error occurred, Write stops. Otherwise it writes the payload.
+func (s bfloat16SliceSerializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, hasGenerics bool, value reflect.Value) {
+	done := writeSliceRefAndType(ctx, refMode, writeType, value, BFLOAT16_ARRAY)
+	if done || ctx.HasError() {
+		return
+	}
+	s.WriteData(ctx, value)
+}
+
+// Read consumes the ref/type header via readSliceRefAndType. When readType is
+// set, it rejects any type id other than BFLOAT16_ARRAY. It then reads the
+// payload.
+func (s bfloat16SliceSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) {
+	done, typeId := readSliceRefAndType(ctx, refMode, readType, value)
+	if done || ctx.HasError() {
+		return
+	}
+	if readType && typeId != uint32(BFLOAT16_ARRAY) {
+		ctx.SetError(DeserializationErrorf("slice type mismatch: expected BFLOAT16_ARRAY (%d), got %d", BFLOAT16_ARRAY, typeId))
+		return
+	}
+	s.ReadData(ctx, value)
+}
+
+// ReadWithTypeInfo reads the value when type info has already been consumed.
+// typeInfo itself is not inspected.
+func (s bfloat16SliceSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) {
+	s.Read(ctx, refMode, false, false, value)
+}
+
+// ReadData reads a BFLOAT16_ARRAY payload into the slice value. An empty
+// payload yields a non-nil empty slice.
+//
+// A negative or odd byte size is rejected. With an odd size, size/2 would
+// truncate and the bulk copy would write one byte past the allocated slice.
+// On the little-endian path the payload is read before the result is
+// allocated, so a declared length exceeding the input is reported by
+// ReadBinary first (presumably ReadBinary bounds-checks; NOTE(review): confirm).
+// On error, the destination slice is left unchanged.
+func (s bfloat16SliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) {
+	buf := ctx.Buffer()
+	ctxErr := ctx.Err()
+	size := buf.ReadLength(ctxErr)
+	if ctx.HasError() {
+		return
+	}
+	if size < 0 || size%2 != 0 {
+		ctx.SetError(DeserializationErrorf("invalid bfloat16 slice byte size %d", size))
+		return
+	}
+	length := size / 2
+
+	// Store through a *[]BFloat16 view. Named slice types with the same
+	// underlying type share this layout.
+	ptr := (*[]bfloat16.BFloat16)(value.Addr().UnsafePointer())
+	if length == 0 {
+		*ptr = make([]bfloat16.BFloat16, 0)
+		return
+	}
+
+	if isLittleEndian {
+		raw := buf.ReadBinary(size, ctxErr)
+		if ctx.HasError() {
+			return
+		}
+		result := make([]bfloat16.BFloat16, length)
+		copy(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size), raw)
+		*ptr = result
+		return
+	}
+
+	result := make([]bfloat16.BFloat16, length)
+	for i := range result {
+		result[i] = bfloat16.BFloat16FromBits(buf.ReadUint16(ctxErr))
+	}
+	if ctx.HasError() {
+		return
+	}
+	*ptr = result
+}
diff --git a/go/fory/slice_primitive_test.go b/go/fory/slice_primitive_test.go
index 82959d0da..e03a4cb00 100644
--- a/go/fory/slice_primitive_test.go
+++ b/go/fory/slice_primitive_test.go
@@ -21,6 +21,7 @@ import (
"math"
"testing"
+ "github.com/apache/fory/go/fory/bfloat16"
"github.com/apache/fory/go/fory/float16"
"github.com/stretchr/testify/assert"
)
@@ -67,6 +68,48 @@ func TestFloat16Slice(t *testing.T) {
})
}
+// TestBFloat16Slice round-trips populated, empty and nil []bfloat16.BFloat16
+// values. Empty must come back non-nil and empty; nil must come back nil.
+func TestBFloat16Slice(t *testing.T) {
+	f := NewFory()
+
+	// roundTrip serializes in, deserializes into a fresh slice, and returns it.
+	roundTrip := func(t *testing.T, in []bfloat16.BFloat16) []bfloat16.BFloat16 {
+		data, err := f.Serialize(in)
+		assert.NoError(t, err)
+
+		var out []bfloat16.BFloat16
+		err = f.Deserialize(data, &out)
+		assert.NoError(t, err)
+		return out
+	}
+
+	t.Run("bfloat16_slice", func(t *testing.T) {
+		in := []bfloat16.BFloat16{
+			bfloat16.BFloat16FromFloat32(1.0),
+			bfloat16.BFloat16FromFloat32(2.5),
+			bfloat16.BFloat16FromFloat32(-3.5),
+		}
+		out := roundTrip(t, in)
+		assert.Equal(t, in, out)
+	})
+
+	t.Run("bfloat16_slice_empty", func(t *testing.T) {
+		out := roundTrip(t, []bfloat16.BFloat16{})
+		assert.NotNil(t, out)
+		assert.Empty(t, out)
+	})
+
+	t.Run("bfloat16_slice_nil", func(t *testing.T) {
+		var in []bfloat16.BFloat16
+		out := roundTrip(t, in)
+		assert.Nil(t, out)
+	})
+}
+
func TestIntSlice(t *testing.T) {
f := NewFory()
diff --git a/go/fory/type_resolver.go b/go/fory/type_resolver.go
index b10a134d8..19cd83997 100644
--- a/go/fory/type_resolver.go
+++ b/go/fory/type_resolver.go
@@ -29,6 +29,7 @@ import (
"time"
"unsafe"
+ "github.com/apache/fory/go/fory/bfloat16"
"github.com/apache/fory/go/fory/float16"
"github.com/apache/fory/go/fory/meta"
)
@@ -77,6 +78,7 @@ var (
float32SliceType = reflect.TypeOf((*[]float32)(nil)).Elem()
float64SliceType = reflect.TypeOf((*[]float64)(nil)).Elem()
float16SliceType = reflect.TypeOf((*[]float16.Float16)(nil)).Elem()
+ bfloat16SliceType =
reflect.TypeOf((*[]bfloat16.BFloat16)(nil)).Elem()
interfaceSliceType = reflect.TypeOf((*[]any)(nil)).Elem()
interfaceMapType = reflect.TypeOf((*map[any]any)(nil)).Elem()
stringStringMapType = reflect.TypeOf((*map[string]string)(nil)).Elem()
@@ -103,6 +105,7 @@ var (
float32Type = reflect.TypeOf((*float32)(nil)).Elem()
float64Type = reflect.TypeOf((*float64)(nil)).Elem()
float16Type = reflect.TypeOf((*float16.Float16)(nil)).Elem()
+ bfloat16Type = reflect.TypeOf((*bfloat16.BFloat16)(nil)).Elem()
dateType = reflect.TypeOf((*Date)(nil)).Elem()
timestampType = reflect.TypeOf((*time.Time)(nil)).Elem()
genericSetType = reflect.TypeOf((*Set[any])(nil)).Elem()
@@ -259,6 +262,7 @@ func newTypeResolver(fory *Fory) *TypeResolver {
float32Type,
float64Type,
float16Type,
+ bfloat16Type,
stringType,
dateType,
timestampType,
@@ -416,6 +420,7 @@ func (r *TypeResolver) initialize() {
{float32SliceType, FLOAT32_ARRAY, float32SliceSerializer{}},
{float64SliceType, FLOAT64_ARRAY, float64SliceSerializer{}},
{float16SliceType, FLOAT16_ARRAY, float16SliceSerializer{}},
+ {bfloat16SliceType, BFLOAT16_ARRAY, bfloat16SliceSerializer{}},
// Register common map types for fast path with optimized
serializers
{stringStringMapType, MAP, stringStringMapSerializer{}},
{stringInt64MapType, MAP, stringInt64MapSerializer{}},
@@ -440,6 +445,7 @@ func (r *TypeResolver) initialize() {
{float32Type, FLOAT32, float32Serializer{}},
{float64Type, FLOAT64, float64Serializer{}},
{float16Type, FLOAT16, float16Serializer{}},
+ {bfloat16Type, BFLOAT16, bfloat16Serializer{}},
{dateType, DATE, dateSerializer{}},
{timestampType, TIMESTAMP, timeSerializer{}},
{genericSetType, SET, setSerializer{}},
@@ -1712,6 +1718,10 @@ func (r *TypeResolver) createSerializer(type_
reflect.Type, mapInStruct bool) (s
if elem == float16Type {
return float16SliceSerializer{}, nil
}
+ // Check for fory.BFloat16 (aliased to uint16)
+ if elem == bfloat16Type {
+ return bfloat16SliceSerializer{}, nil
+ }
return uint16SliceSerializer{}, nil
case reflect.Uint32:
return uint32SliceSerializer{}, nil
@@ -1763,6 +1773,10 @@ func (r *TypeResolver) createSerializer(type_
reflect.Type, mapInStruct bool) (s
if elem == float16Type {
return float16ArraySerializer{arrayType:
type_}, nil
}
+ // Check for fory.BFloat16 (aliased to uint16)
+ if elem == bfloat16Type {
+ return bfloat16ArraySerializer{arrayType:
type_}, nil
+ }
return uint16ArraySerializer{arrayType: type_}, nil
case reflect.Uint32:
return uint32ArraySerializer{arrayType: type_}, nil
diff --git a/go/fory/type_test.go b/go/fory/type_test.go
index 5b9766930..e8978e593 100644
--- a/go/fory/type_test.go
+++ b/go/fory/type_test.go
@@ -21,6 +21,7 @@ import (
"reflect"
"testing"
+ "github.com/apache/fory/go/fory/bfloat16"
"github.com/apache/fory/go/fory/float16"
"github.com/stretchr/testify/require"
)
@@ -41,6 +42,9 @@ func TestTypeResolver(t *testing.T) {
{reflect.TypeOf((*int)(nil)), "*int"},
{reflect.TypeOf((*[10]int)(nil)), "*[10]int"},
{reflect.TypeOf((*[10]int)(nil)).Elem(), "[10]int"},
+ {reflect.TypeOf((*bfloat16.BFloat16)(nil)).Elem(),
"bfloat16.BFloat16"},
+ {reflect.TypeOf((*[]bfloat16.BFloat16)(nil)).Elem(),
"[]bfloat16.BFloat16"},
+ {reflect.TypeOf((*[10]bfloat16.BFloat16)(nil)).Elem(),
"[10]bfloat16.BFloat16"},
{reflect.TypeOf((*[]map[string][]map[string]*any)(nil)).Elem(),
"[]map[string][]map[string]*interface {}"},
{reflect.TypeOf((*A)(nil)), "*@example.A"},
@@ -84,6 +88,7 @@ func TestCreateSerializerSliceTypes(t *testing.T) {
{reflect.TypeOf([]uint{}),
reflect.TypeOf(uintSliceSerializer{})},
{reflect.TypeOf([]uint16{}),
reflect.TypeOf(uint16SliceSerializer{})},
{reflect.TypeOf([]float16.Float16{}),
reflect.TypeOf(float16SliceSerializer{})},
+ {reflect.TypeOf([]bfloat16.BFloat16{}),
reflect.TypeOf(bfloat16SliceSerializer{})},
{reflect.TypeOf([]uint32{}),
reflect.TypeOf(uint32SliceSerializer{})},
{reflect.TypeOf([]uint64{}),
reflect.TypeOf(uint64SliceSerializer{})},
{reflect.TypeOf([]string{}),
reflect.TypeOf(stringSliceSerializer{})},
@@ -135,6 +140,7 @@ func TestCreateSerializerArrayTypes(t *testing.T) {
{reflect.TypeOf([4]byte{}),
reflect.TypeOf(uint8ArraySerializer{})},
{reflect.TypeOf([4]uint16{}),
reflect.TypeOf(uint16ArraySerializer{})},
{reflect.TypeOf([4]float16.Float16{}),
reflect.TypeOf(float16ArraySerializer{})},
+ {reflect.TypeOf([4]bfloat16.BFloat16{}),
reflect.TypeOf(bfloat16ArraySerializer{})},
{reflect.TypeOf([4]uint32{}),
reflect.TypeOf(uint32ArraySerializer{})},
{reflect.TypeOf([4]uint64{}),
reflect.TypeOf(uint64ArraySerializer{})},
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]