zeroshade commented on code in PR #13768: URL: https://github.com/apache/arrow/pull/13768#discussion_r936712776
########## go/arrow/datatype_nested.go: ########## @@ -329,6 +333,208 @@ func (t *MapType) Layout() DataTypeLayout { return t.value.Layout() } +type ( + UnionTypeCode = int8 + UnionMode int8 +) + +const ( + MaxUnionTypeCode UnionTypeCode = 127 + InvalidUnionChildID int = -1 + + SparseMode UnionMode = iota + DenseMode +) + +type UnionType interface { + NestedType + Mode() UnionMode + ChildIDs() []int + TypeCodes() []UnionTypeCode + MaxTypeCode() UnionTypeCode +} + +func UnionOf(mode UnionMode, fields []Field, typeCodes []UnionTypeCode) UnionType { + switch mode { + case SparseMode: + return SparseUnionOf(fields, typeCodes) + case DenseMode: + return DenseUnionOf(fields, typeCodes) + default: + panic("arrow: invalid union mode") + } +} + +type unionType struct { + children []Field + typeCodes []UnionTypeCode + childIDs [int(MaxUnionTypeCode) + 1]int +} + +func (t *unionType) init(fields []Field, typeCodes []UnionTypeCode) { + // initialize all child IDs to -1 + t.childIDs[0] = InvalidUnionChildID + for i := 1; i < len(t.childIDs); i *= 2 { + copy(t.childIDs[i:], t.childIDs[:i]) + } + + t.children = fields + t.typeCodes = typeCodes + + for i, tc := range t.typeCodes { + t.childIDs[tc] = i + } +} + +func (t unionType) Fields() []Field { return t.children } +func (t unionType) TypeCodes() []UnionTypeCode { return t.typeCodes } +func (t unionType) ChildIDs() []int { return t.childIDs[:] } + +func (t *unionType) validate(fields []Field, typeCodes []UnionTypeCode, _ UnionMode) error { + if len(fields) != len(typeCodes) { + return errors.New("arrow: union types should have the same number of fields as type codes") + } + + for _, c := range typeCodes { + if c < 0 || c > MaxUnionTypeCode { + return errors.New("arrow: union type code out of bounds") + } + } + return nil +} + +func (t *unionType) MaxTypeCode() (max UnionTypeCode) { + if len(t.typeCodes) == 0 { + return + } + + max = t.typeCodes[0] + for _, c := range t.typeCodes[1:] { + if c > max { + max = c + } + } + return +} + +func (t *unionType) String() string { + var b strings.Builder + b.WriteByte('<') + for i := range t.typeCodes { + if i != 0 { + b.WriteString(", ") + } + fmt.Fprintf(&b, "%s=%d", t.children[i], t.typeCodes[i]) + } + b.WriteByte('>') + return b.String() +} + +func (t *unionType) fingerprint() string { + var b strings.Builder + for _, c := range t.typeCodes { + fmt.Fprintf(&b, ":%d", c) + } + b.WriteString("]{") + for _, c := range t.children { + fingerprint := c.Fingerprint() + if len(fingerprint) == 0 { + return "" + } + b.WriteString(fingerprint) + b.WriteByte(';') + } + b.WriteByte('}') + return b.String() +} + +func fieldsFromArrays(arrays []Array, names ...string) (ret []Field) { + ret = make([]Field, len(arrays)) + if len(names) == 0 { + for i, c := range arrays { + ret[i] = Field{Name: strconv.Itoa(i), Type: c.DataType(), Nullable: true} + } + } else { + debug.Assert(len(names) == len(arrays), "mismatch of arrays and names") + for i, c := range arrays { + ret[i] = Field{Name: names[i], Type: c.DataType(), Nullable: true} + } + } + return +} + +type SparseUnionType struct { + unionType +} + +func SparseUnionFromArrays(children []Array, fields []string, codes []UnionTypeCode) *SparseUnionType { + if len(codes) == 0 { + codes = make([]UnionTypeCode, len(children)) + for i := range children { + codes[i] = UnionTypeCode(i) + } + } + return SparseUnionOf(fieldsFromArrays(children, fields...), codes) +} + +func SparseUnionOf(fields []Field, typeCodes []UnionTypeCode) *SparseUnionType { + ret := &SparseUnionType{} + if err := ret.validate(fields, typeCodes, ret.Mode()); err != nil { + panic(err) + } + ret.init(fields, typeCodes) + return ret +} + +func (SparseUnionType) ID() Type { return SPARSE_UNION } +func (SparseUnionType) Name() string { return "sparse_union" } +func (SparseUnionType) Mode() UnionMode { return SparseMode } +func (t *SparseUnionType) Fingerprint() string { + return typeFingerprint(t) + "[s" + t.fingerprint() +} +func (SparseUnionType) Layout() DataTypeLayout { + return DataTypeLayout{Buffers: []BufferSpec{SpecAlwaysNull(), SpecFixedWidth(Uint8SizeBytes)}} Review Comment: Yup yup. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org