This is an automated email from the ASF dual-hosted git repository.
mappjzc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-devlake.git
The following commit(s) were added to refs/heads/main by this push:
new 04bb9d42 refactor: reflect for get primary key (#2422)
04bb9d42 is described below
commit 04bb9d42018001db51983fddcb374b7bbe18fefa
Author: mappjzc <[email protected]>
AuthorDate: Thu Jul 7 16:22:51 2022 +0800
refactor: reflect for get primary key (#2422)
* refactor: reflect for get primary key
Add GetPrimaryKeyNameFromDB to dal
Add GetPrimaryKey to dal
Add WalkFields to utils
Nddtfjiang <[email protected]>
* refactor: rename get primarykey
Rename GetPrimaryKeyNameFromDB to GetPrimarykeyColumnNames.
Rename GetPrimaryKey to GetPrimarykeyFields.
Nddtfjiang <[email protected]>
---
helpers/e2ehelper/data_flow_tester.go | 75 ++++++++++------------
helpers/e2ehelper/data_flow_tester_test.go | 21 +++---
impl/dalgorm/dalgorm.go | 70 ++++++++------------
models/domainlayer/didgen/domain_id_generator.go | 44 ++++---------
plugins/core/dal/dal.go | 57 +++++++++++++++-
plugins/helper/batch_save.go | 53 +++++----------
plugins/helper/batch_save_divider_test.go | 10 ++-
plugins/helper/config_util.go | 45 ++++---------
plugins/starrocks/tasks.go | 50 +++++++++++----
plugins/starrocks/utils.go => utils/structfield.go | 37 ++++++-----
10 files changed, 228 insertions(+), 234 deletions(-)
diff --git a/helpers/e2ehelper/data_flow_tester.go
b/helpers/e2ehelper/data_flow_tester.go
index 28c33875..26cdca16 100644
--- a/helpers/e2ehelper/data_flow_tester.go
+++ b/helpers/e2ehelper/data_flow_tester.go
@@ -22,6 +22,13 @@ import (
"database/sql"
"encoding/json"
"fmt"
+ "os"
+ "strconv"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+
"github.com/apache/incubator-devlake/config"
"github.com/apache/incubator-devlake/helpers/pluginhelper"
"github.com/apache/incubator-devlake/impl/dalgorm"
@@ -35,12 +42,6 @@ import (
"github.com/stretchr/testify/assert"
"gorm.io/gorm"
"gorm.io/gorm/schema"
- "os"
- "strconv"
- "strings"
- "sync"
- "testing"
- "time"
)
// DataFlowTester provides a universal data integrity validation facility to
help `Plugin` verifying records between
@@ -186,14 +187,7 @@ func (t *DataFlowTester) Subtask(subtaskMeta
core.SubTaskMeta, taskData interfac
}
}
-func (t *DataFlowTester) getPkFields(dst schema.Tabler) []string {
- return t.getFields(dst, func(column gorm.ColumnType) bool {
- isPk, _ := column.PrimaryKey()
- return isPk
- })
-}
-
-func filterColumn(column gorm.ColumnType, opts TableOptions) bool {
+func filterColumn(column dal.ColumnMeta, opts TableOptions) bool {
for _, ignore := range opts.IgnoreFields {
if column.Name() == ignore {
return false
@@ -212,31 +206,21 @@ func filterColumn(column gorm.ColumnType, opts
TableOptions) bool {
return targetFound
}
-func (t *DataFlowTester) getFields(dst schema.Tabler, filter func(column
gorm.ColumnType) bool) []string {
- columnTypes, err := t.Db.Migrator().ColumnTypes(dst)
- var fields []string
- if err != nil {
- panic(err)
- }
- for _, columnType := range columnTypes {
- if filter == nil || filter(columnType) {
- fields = append(fields, columnType.Name())
- }
- }
- return fields
-}
-
// CreateSnapshot reads rows from database and write them into .csv file.
func (t *DataFlowTester) CreateSnapshot(dst schema.Tabler, opts TableOptions) {
location, _ := time.LoadLocation(`UTC`)
- pkFields := t.getPkFields(dst)
+
targetFields := t.resolveTargetFields(dst, opts)
- allFields := append(pkFields, targetFields...)
+ pkColumnNames, err := dal.GetPrimarykeyColumnNames(t.Dal, dst)
+ if err != nil {
+ panic(err)
+ }
+ allFields := append(pkColumnNames, targetFields...)
allFields = utils.StringsUniq(allFields)
dbCursor, err := t.Dal.Cursor(
dal.Select(strings.Join(allFields, `,`)),
dal.From(dst.TableName()),
- dal.Orderby(strings.Join(pkFields, `,`)),
+ dal.Orderby(strings.Join(pkColumnNames, `,`)),
)
if err != nil {
panic(fmt.Errorf("unable to run select query on table %s: %v",
dst.TableName(), err))
@@ -369,9 +353,13 @@ func (t *DataFlowTester) resolveTargetFields(dst
schema.Tabler, opts TableOption
}
var targetFields []string
if len(opts.TargetFields) == 0 || len(opts.IgnoreFields) > 0 {
- targetFields = append(targetFields, t.getFields(dst,
func(column gorm.ColumnType) bool {
- return filterColumn(column, opts)
- })...)
+ names, err := dal.GetColumnNames(t.Dal, dst, func(cm
dal.ColumnMeta) bool {
+ return filterColumn(cm, opts)
+ })
+ if err != nil {
+ panic(err)
+ }
+ targetFields = append(targetFields, names...)
} else {
targetFields = opts.TargetFields
}
@@ -388,17 +376,22 @@ func (t *DataFlowTester) VerifyTableWithOptions(dst
schema.Tabler, opts TableOpt
t.CreateSnapshot(dst, opts)
return
}
+
targetFields := t.resolveTargetFields(dst, opts)
- pkFields := t.getPkFields(dst)
+ pkColumns, err := dal.GetPrimarykeyColumns(t.Dal, dst)
+ if err != nil {
+ panic(err)
+ }
+
csvIter := pluginhelper.NewCsvFileIterator(opts.CSVRelPath)
defer csvIter.Close()
expectedTotal := int64(0)
csvMap := map[string]map[string]interface{}{}
for csvIter.HasNext() {
expected := csvIter.Fetch()
- pkValues := make([]string, 0, len(pkFields))
- for _, pkf := range pkFields {
- pkValues = append(pkValues, expected[pkf].(string))
+ pkValues := make([]string, 0, len(pkColumns))
+ for _, pkc := range pkColumns {
+ pkValues = append(pkValues,
expected[pkc.Name()].(string))
}
pkValueStr := strings.Join(pkValues, `-`)
_, ok := csvMap[pkValueStr]
@@ -414,9 +407,9 @@ func (t *DataFlowTester) VerifyTableWithOptions(dst
schema.Tabler, opts TableOpt
panic(err)
}
for _, actual := range *dbRows {
- pkValues := make([]string, 0, len(pkFields))
- for _, pkf := range pkFields {
- pkValues = append(pkValues, formatDbValue(actual[pkf]))
+ pkValues := make([]string, 0, len(pkColumns))
+ for _, pkc := range pkColumns {
+ pkValues = append(pkValues,
formatDbValue(actual[pkc.Name()]))
}
expected, ok := csvMap[strings.Join(pkValues, `-`)]
assert.True(t.T, ok, fmt.Sprintf(`%s not found (with params
from csv %s)`, dst.TableName(), pkValues))
diff --git a/helpers/e2ehelper/data_flow_tester_test.go
b/helpers/e2ehelper/data_flow_tester_test.go
index e98d792f..c13d9945 100644
--- a/helpers/e2ehelper/data_flow_tester_test.go
+++ b/helpers/e2ehelper/data_flow_tester_test.go
@@ -18,11 +18,12 @@ limitations under the License.
package e2ehelper
import (
+ "testing"
+
"github.com/apache/incubator-devlake/models/common"
+ "github.com/apache/incubator-devlake/plugins/core/dal"
gitlabModels "github.com/apache/incubator-devlake/plugins/gitlab/models"
"github.com/stretchr/testify/assert"
- "gorm.io/gorm"
- "testing"
"github.com/apache/incubator-devlake/plugins/core"
"github.com/apache/incubator-devlake/plugins/gitlab/tasks"
@@ -87,11 +88,10 @@ func TestGetTableMetaData(t *testing.T) {
var meta core.PluginMeta
dataflowTester := NewDataFlowTester(t, "test_dataflow", meta)
dataflowTester.FlushTabler(&TestModel{})
- t.Run("get_fields", func(t *testing.T) {
- fields := dataflowTester.getFields(&TestModel{}, func(column
gorm.ColumnType) bool {
- return true
- })
- assert.Equal(t, 9, len(fields))
+ t.Run("dal_get_columns", func(t *testing.T) {
+ names, err := dal.GetColumnNames(dataflowTester.Dal,
&TestModel{}, nil)
+ assert.Equal(t, err, nil)
+ assert.Equal(t, 9, len(names))
for _, e := range []string{
"connection_id",
"issue_id",
@@ -103,7 +103,7 @@ func TestGetTableMetaData(t *testing.T) {
"_raw_data_id",
"_raw_data_remark",
} {
- assert.Contains(t, fields, e)
+ assert.Contains(t, names, e)
}
})
t.Run("extract_columns", func(t *testing.T) {
@@ -118,8 +118,9 @@ func TestGetTableMetaData(t *testing.T) {
assert.Contains(t, columns, e)
}
})
- t.Run("get_pk_fields", func(t *testing.T) {
- fields := dataflowTester.getPkFields(&TestModel{})
+ t.Run("dal_get_pk_column_names", func(t *testing.T) {
+ fields, err := dal.GetPrimarykeyColumnNames(dataflowTester.Dal,
&TestModel{})
+ assert.Equal(t, err, nil)
assert.Equal(t, 3, len(fields))
for _, e := range []string{
"connection_id",
diff --git a/impl/dalgorm/dalgorm.go b/impl/dalgorm/dalgorm.go
index 773c2e06..d5ead148 100644
--- a/impl/dalgorm/dalgorm.go
+++ b/impl/dalgorm/dalgorm.go
@@ -19,27 +19,20 @@ package dalgorm
import (
"database/sql"
- "fmt"
+ "reflect"
"strings"
"github.com/apache/incubator-devlake/plugins/core/dal"
+ "github.com/apache/incubator-devlake/utils"
"gorm.io/gorm"
"gorm.io/gorm/clause"
+ "gorm.io/gorm/schema"
)
type Dalgorm struct {
db *gorm.DB
}
-// To accommodate gorm
-//type stubTable struct {
-//name string
-//}
-
-//func (s *stubTable) TableName() string {
-//return s.name
-//}
-
func buildTx(tx *gorm.DB, clauses []dal.Clause) *gorm.DB {
for _, c := range clauses {
t := c.Type
@@ -151,6 +144,28 @@ func (d *Dalgorm) Delete(entity interface{}, clauses
...dal.Clause) error {
return buildTx(d.db, clauses).Delete(entity).Error
}
+func (d *Dalgorm) GetColumns(dst schema.Tabler, filter func(columnMeta
dal.ColumnMeta) bool) (cms []dal.ColumnMeta, err error) {
+ columnTypes, err := d.db.Migrator().ColumnTypes(dst.TableName())
+ if err != nil {
+ return
+ }
+ for _, columnType := range columnTypes {
+ if filter == nil {
+ cms = append(cms, columnType)
+ } else if filter(columnType) {
+ cms = append(cms, columnType)
+ }
+ }
+ return
+}
+
+// GetPrimaryKey get the PrimaryKey from `gorm` tag
+func (d *Dalgorm) GetPrimarykeyFields(t reflect.Type) []reflect.StructField {
+ return utils.WalkFields(t, func(field *reflect.StructField) bool {
+ return strings.Contains(strings.ToLower(field.Tag.Get("gorm")),
"primarykey")
+ })
+}
+
// AllTables returns all tables in the database
func (d *Dalgorm) AllTables() ([]string, error) {
var tableSql string
@@ -173,41 +188,6 @@ func (d *Dalgorm) AllTables() ([]string, error) {
return filteredTables, nil
}
-// GetTableColumns returns table columns in database
-func (d *Dalgorm) GetTableColumns(table string) (map[string]string, error) {
- var columnSql string
- ret := make(map[string]string)
- if d.db.Dialector.Name() == "mysql" {
- type MySQLColumn struct {
- Field string
- Type string
- }
- var result []MySQLColumn
- columnSql = fmt.Sprintf("show columns from %s", table)
- err := d.db.Raw(columnSql).Scan(&result).Error
- if err != nil {
- return nil, err
- }
- for _, item := range result {
- ret[item.Field] = item.Type
- }
- } else {
- columnSql = fmt.Sprintf("select column_name,data_type from
information_schema.COLUMNS where TABLE_NAME='%s' and TABLE_SCHEMA='public'",
table)
- type PostgresColumn struct {
- ColumnName string `gorm:"column_name"`
- DataType string `gorm:"data_type"`
- }
- var result []PostgresColumn
- err := d.db.Raw(columnSql).Scan(&result).Error
- if err != nil {
- return nil, err
- }
- for _, item := range result {
- ret[item.ColumnName] = item.DataType
- }
- }
- return ret, nil
-}
func NewDalgorm(db *gorm.DB) *Dalgorm {
return &Dalgorm{db}
}
diff --git a/models/domainlayer/didgen/domain_id_generator.go
b/models/domainlayer/didgen/domain_id_generator.go
index 06a81f3a..09087a58 100644
--- a/models/domainlayer/didgen/domain_id_generator.go
+++ b/models/domainlayer/didgen/domain_id_generator.go
@@ -20,15 +20,14 @@ package didgen
import (
"fmt"
"reflect"
- "strings"
+ "github.com/apache/incubator-devlake/impl/dalgorm"
"github.com/apache/incubator-devlake/plugins/core"
)
type DomainIdGenerator struct {
- prefix string
- pkNames []string
- pkTypes []reflect.Type
+ prefix string
+ pk []reflect.StructField
}
type WildCard string
@@ -37,23 +36,6 @@ const WILDCARD WildCard = "%"
var wildcardType = reflect.TypeOf(WILDCARD)
-func walkFields(t reflect.Type, pkNames *[]string, pkTypes *[]reflect.Type) {
- for i := 0; i < t.NumField(); i++ {
- field := t.Field(i)
- if field.Type.Kind() == reflect.Struct {
- walkFields(field.Type, pkNames, pkTypes)
- } else {
- gormTag := field.Tag.Get("gorm")
-
- // TODO: regex?
- if gormTag != "" &&
strings.Contains(strings.ToLower(gormTag), "primarykey") {
- *pkNames = append(*pkNames, field.Name)
- *pkTypes = append(*pkTypes, field.Type)
- }
- }
- }
-}
-
func NewDomainIdGenerator(entityPtr interface{}) *DomainIdGenerator {
v := reflect.ValueOf(entityPtr)
if v.Kind() != reflect.Ptr {
@@ -69,20 +51,16 @@ func NewDomainIdGenerator(entityPtr interface{})
*DomainIdGenerator {
// find out entity type name
structName := t.Name()
- // find out all primkary keys and their types
- pkNames := make([]string, 0, 1)
- pkTypes := make([]reflect.Type, 0, 1)
-
- walkFields(t, &pkNames, &pkTypes)
+ dal := &dalgorm.Dalgorm{}
+ pk := dal.GetPrimarykeyFields(t)
- if len(pkNames) == 0 {
+ if len(pk) == 0 {
panic(fmt.Errorf("no primary key found for %s:%s", pluginName,
structName))
}
return &DomainIdGenerator{
- prefix: fmt.Sprintf("%s:%s", pluginName, structName),
- pkNames: pkNames,
- pkTypes: pkTypes,
+ prefix: fmt.Sprintf("%s:%s", pluginName, structName),
+ pk: pk,
}
}
@@ -95,11 +73,11 @@ func (g *DomainIdGenerator) Generate(pkValues
...interface{}) string {
pkValueType := reflect.TypeOf(pkValue)
if pkValueType == wildcardType {
break
- } else if pkValueType != g.pkTypes[i] {
+ } else if pkValueType != g.pk[i].Type {
panic(fmt.Errorf("primary key type does not match: %s
is %s type, and it should be %s type",
- g.pkNames[i],
+ g.pk[i].Name,
pkValueType.Name(),
- g.pkTypes[i].Name(),
+ g.pk[i].Type.Name(),
))
}
}
diff --git a/plugins/core/dal/dal.go b/plugins/core/dal/dal.go
index 85d74c3b..4cc14dab 100644
--- a/plugins/core/dal/dal.go
+++ b/plugins/core/dal/dal.go
@@ -19,6 +19,9 @@ package dal
import (
"database/sql"
+ "reflect"
+
+ "gorm.io/gorm/schema"
)
type Clause struct {
@@ -26,6 +29,22 @@ type Clause struct {
Data interface{}
}
+// ColumnType column type interface
+type ColumnMeta interface {
+ Name() string
+ DatabaseTypeName() string // varchar
+ ColumnType() (columnType string, ok bool) // varchar(64)
+ PrimaryKey() (isPrimaryKey bool, ok bool)
+ AutoIncrement() (isAutoIncrement bool, ok bool)
+ Length() (length int64, ok bool)
+ DecimalSize() (precision int64, scale int64, ok bool)
+ Nullable() (nullable bool, ok bool)
+ Unique() (unique bool, ok bool)
+ ScanType() reflect.Type
+ Comment() (value string, ok bool)
+ DefaultValue() (value string, ok bool)
+}
+
// Dal aims to facilitate an isolation between DBS and our System by defining
a set of operations should a DBS provide
type Dal interface {
// AutoMigrate runs auto migration for given entity
@@ -58,8 +77,42 @@ type Dal interface {
Delete(entity interface{}, clauses ...Clause) error
// AllTables returns all tables in database
AllTables() ([]string, error)
- // GetTableColumns returns table columns in database
- GetTableColumns(table string) (map[string]string, error)
+ // GetColumns returns table columns in database
+ GetColumns(dst schema.Tabler, filter func(columnMeta ColumnMeta) bool)
(cms []ColumnMeta, err error)
+ // GetPrimarykeyFields get the PrimaryKey from `gorm` tag
+ GetPrimarykeyFields(t reflect.Type) []reflect.StructField
+}
+
+// GetPrimarykeyColumnNames returns table Column Names in database
+func GetColumnNames(d Dal, dst schema.Tabler, filter func(columnMeta
ColumnMeta) bool) (names []string, err error) {
+ columns, err := d.GetColumns(dst, filter)
+ if err != nil {
+ return
+ }
+ for _, pkColumn := range columns {
+ names = append(names, pkColumn.Name())
+ }
+ return
+}
+
+// GetPrimarykeyColumns get returns PrimaryKey table Meta in database
+func GetPrimarykeyColumns(d Dal, dst schema.Tabler) ([]ColumnMeta, error) {
+ return d.GetColumns(dst, func(columnMeta ColumnMeta) bool {
+ isPrimaryKey, ok := columnMeta.PrimaryKey()
+ return isPrimaryKey && ok
+ })
+}
+
+// GetPrimarykeyColumnNames get returns PrimaryKey Column Names in database
+func GetPrimarykeyColumnNames(d Dal, dst schema.Tabler) (names []string, err
error) {
+ pkColumns, err := GetPrimarykeyColumns(d, dst)
+ if err != nil {
+ return
+ }
+ for _, pkColumn := range pkColumns {
+ names = append(names, pkColumn.Name())
+ }
+ return
}
type DalClause struct {
diff --git a/plugins/helper/batch_save.go b/plugins/helper/batch_save.go
index 28643fee..84c021d5 100644
--- a/plugins/helper/batch_save.go
+++ b/plugins/helper/batch_save.go
@@ -38,6 +38,7 @@ type BatchSave struct {
current int
size int
valueIndex map[string]int
+ primaryKey []reflect.StructField
}
const BATCH_SAVE_UPDATE_ONLY = 0
@@ -47,18 +48,23 @@ func NewBatchSave(basicRes core.BasicRes, slotType
reflect.Type, size int) (*Bat
if slotType.Kind() != reflect.Ptr {
return nil, fmt.Errorf("slotType must be a pointer")
}
- if !hasPrimaryKey(slotType) {
+ dal := basicRes.GetDal()
+ primaryKey := dal.GetPrimarykeyFields(slotType)
+ // check if it have primaryKey
+ if len(primaryKey) == 0 {
return nil, fmt.Errorf("%s no primary key", slotType.String())
}
+
log := basicRes.GetLogger().Nested(slotType.String())
return &BatchSave{
basicRes: basicRes,
log: log,
- db: basicRes.GetDal(),
+ db: dal,
slotType: slotType,
slots: reflect.MakeSlice(reflect.SliceOf(slotType), size,
size),
size: size,
valueIndex: make(map[string]int),
+ primaryKey: primaryKey,
}, nil
}
@@ -72,7 +78,7 @@ func (c *BatchSave) Add(slot interface{}) error {
return fmt.Errorf("slot is not a pointer")
}
// deduplication
- key := getPrimaryKeyValue(slot)
+ key := getKeyValue(slot, c.primaryKey)
if key != "" {
if index, ok := c.valueIndex[key]; !ok {
@@ -113,46 +119,17 @@ func (c *BatchSave) Close() error {
return nil
}
-func isPrimaryKey(f reflect.StructField) bool {
- tag := strings.TrimSpace(f.Tag.Get("gorm"))
- return strings.HasPrefix(strings.ToLower(tag), "primarykey")
-}
-
-func hasPrimaryKey(ifv reflect.Type) bool {
- if ifv.Kind() == reflect.Ptr {
- ifv = ifv.Elem()
- }
- for i := 0; i < ifv.NumField(); i++ {
- v := ifv.Field(i)
- if ok := isPrimaryKey(v); ok {
- return true
- } else if v.Type.Kind() == reflect.Struct {
- if ok := hasPrimaryKey(v.Type); ok {
- return true
- }
- }
- }
- return false
-}
-
-func getPrimaryKeyValue(iface interface{}) string {
+func getKeyValue(iface interface{}, primaryKey []reflect.StructField) string {
var ss []string
ifv := reflect.ValueOf(iface)
if ifv.Kind() == reflect.Ptr {
ifv = ifv.Elem()
}
- for i := 0; i < ifv.NumField(); i++ {
- v := ifv.Field(i)
- if isPrimaryKey(ifv.Type().Field(i)) {
- s := fmt.Sprintf("%v", v.Interface())
- if s != "" {
- ss = append(ss, s)
- }
- } else if v.Kind() == reflect.Struct {
- s := getPrimaryKeyValue(v.Interface())
- if s != "" {
- ss = append(ss, s)
- }
+ for _, key := range primaryKey {
+ v := ifv.FieldByName(key.Name)
+ s := fmt.Sprintf("%v", v.Interface())
+ if s != "" {
+ ss = append(ss, s)
}
}
return strings.Join(ss, ":")
diff --git a/plugins/helper/batch_save_divider_test.go
b/plugins/helper/batch_save_divider_test.go
index ae933a05..cabef8de 100644
--- a/plugins/helper/batch_save_divider_test.go
+++ b/plugins/helper/batch_save_divider_test.go
@@ -54,13 +54,19 @@ func TestBatchSaveDivider(t *testing.T) {
mockLog := unithelper.DummyLogger()
mockRes := new(mocks.BasicRes)
+
mockRes.On("GetDal").Return(mockDal)
mockRes.On("GetLogger").Return(mockLog)
- divider := NewBatchSaveDivider(mockRes, 10, "", "")
-
// we expect total 2 deletion calls after all code got carried out
mockDal.On("Delete", mock.Anything, mock.Anything).Return(nil).Twice()
+ mockDal.On("GetPrimarykeyFields", mock.Anything).Return(
+ []reflect.StructField{
+ {Name: "ID", Type: reflect.TypeOf("")},
+ },
+ )
+
+ divider := NewBatchSaveDivider(mockRes, 10, "", "")
// for same type should return the same BatchSave
jiraIssue1, err := divider.ForType(reflect.TypeOf(&MockJirIssueBsd{}))
diff --git a/plugins/helper/config_util.go b/plugins/helper/config_util.go
index 65a51a3c..ee51d007 100644
--- a/plugins/helper/config_util.go
+++ b/plugins/helper/config_util.go
@@ -21,6 +21,7 @@ import (
"fmt"
"reflect"
+ "github.com/apache/incubator-devlake/utils"
"github.com/go-playground/validator/v10"
"github.com/mitchellh/mapstructure"
"github.com/spf13/viper"
@@ -50,16 +51,11 @@ func DecodeStruct(output *viper.Viper, input interface{},
data map[string]interf
if vf.Kind() != reflect.Ptr {
return fmt.Errorf("input %v is not a pointer", input)
}
- tf := reflect.Indirect(vf).Type()
- fieldTags := make([]string, 0)
- fieldNames := make([]string, 0)
- fieldTypes := make([]reflect.Type, 0)
- walkFields(tf, &fieldNames, &fieldTypes, &fieldTags, tag)
- length := len(fieldNames)
- for i := 0; i < length; i++ {
- fieldName := fieldNames[i]
- fieldType := fieldTypes[i]
- fieldTag := fieldTags[i]
+
+ for _, f := range utils.WalkFields(reflect.Indirect(vf).Type(), nil) {
+ fieldName := f.Name
+ fieldType := f.Type
+ fieldTag := f.Tag.Get(tag)
// Check if the first letter is uppercase (indicates a public
element, accessible)
ascii := rune(fieldName[0])
@@ -114,16 +110,11 @@ func EncodeStruct(input *viper.Viper, output interface{},
tag string) error {
if vf.Kind() != reflect.Ptr {
return fmt.Errorf("output %v is not a pointer", output)
}
- tf := reflect.Indirect(vf).Type()
- fieldTags := make([]string, 0)
- fieldNames := make([]string, 0)
- fieldTypes := make([]reflect.Type, 0)
- walkFields(tf, &fieldNames, &fieldTypes, &fieldTags, tag)
- length := len(fieldNames)
- for i := 0; i < length; i++ {
- fieldName := fieldNames[i]
- fieldType := fieldTypes[i]
- fieldTag := fieldTags[i]
+
+ for _, f := range utils.WalkFields(reflect.Indirect(vf).Type(), nil) {
+ fieldName := f.Name
+ fieldType := f.Type
+ fieldTag := f.Tag.Get(tag)
// Check if the first letter is uppercase (indicates a public
element, accessible)
ascii := rune(fieldName[0])
@@ -186,17 +177,3 @@ func EncodeStruct(input *viper.Viper, output interface{},
tag string) error {
}
return nil
}
-
-func walkFields(t reflect.Type, fieldNames *[]string, fieldTypes
*[]reflect.Type, fieldTags *[]string, tag string) {
- for i := 0; i < t.NumField(); i++ {
- field := t.Field(i)
- if field.Type.Kind() == reflect.Struct {
- walkFields(field.Type, fieldNames, fieldTypes,
fieldTags, tag)
- } else {
- fieldTag := field.Tag.Get(tag)
- *fieldNames = append(*fieldNames, field.Name)
- *fieldTypes = append(*fieldTypes, field.Type)
- *fieldTags = append(*fieldTags, fieldTag)
- }
- }
-}
diff --git a/plugins/starrocks/tasks.go b/plugins/starrocks/tasks.go
index b14598b1..be98a7c6 100644
--- a/plugins/starrocks/tasks.go
+++ b/plugins/starrocks/tasks.go
@@ -30,6 +30,14 @@ import (
"github.com/apache/incubator-devlake/plugins/core/dal"
)
+type Table struct {
+ name string
+}
+
+func (t *Table) TableName() string {
+ return t.name
+}
+
func LoadData(c core.SubTaskContext) error {
config := c.GetData().(*StarRocksConfig)
db := c.GetDal()
@@ -76,33 +84,47 @@ func LoadData(c core.SubTaskContext) error {
return nil
}
func createTable(starrocks *sql.DB, db dal.Dal, starrocksTable string, table
string, c core.SubTaskContext, extra string) error {
- columnMap, err := db.GetTableColumns(table)
+
+ columeMetas, err := db.GetColumns(&Table{name: table}, nil)
if err != nil {
return err
}
- var pk string
- if _, ok := columnMap["id"]; ok {
- pk = "id"
- } else {
- for k := range columnMap {
- pk = k
- break
- }
- }
+ var pks string
var columns []string
- for field, dataType := range columnMap {
- starrocksDatatype := getDataType(dataType)
- column := fmt.Sprintf("%s %s", field, starrocksDatatype)
+ firstcm := ""
+ for _, cm := range columeMetas {
+ name := cm.Name()
+ starrocksDatatype, ok := cm.ColumnType()
+ if !ok {
+ return fmt.Errorf("Get [%s] ColumeType Failed", name)
+ }
+ column := fmt.Sprintf("%s %s", name, starrocksDatatype)
columns = append(columns, column)
+ isPrimaryKey, ok := cm.PrimaryKey()
+ if isPrimaryKey && ok {
+ if pks != "" {
+ pks += ","
+ }
+ pks += name
+ }
+ if firstcm == "" {
+ firstcm = name
+ }
}
+
+ if pks == "" {
+ pks = firstcm
+ }
+
if extra == "" {
- extra = fmt.Sprintf(`engine=olap distributed by hash(%s)
properties("replication_num" = "1")`, pk)
+ extra = fmt.Sprintf(`engine=olap distributed by hash(%s)
properties("replication_num" = "1")`, pks)
}
tableSql := fmt.Sprintf(`create table if not exists %s ( %s ) %s`,
starrocksTable, strings.Join(columns, ","), extra)
c.GetLogger().Info(tableSql)
_, err = starrocks.Exec(tableSql)
return err
}
+
func loadData(starrocks *sql.DB, c core.SubTaskContext, starrocksTable string,
table string, db dal.Dal, config *StarRocksConfig) error {
offset := 0
starrocksTmpTable := starrocksTable + "_tmp"
diff --git a/plugins/starrocks/utils.go b/utils/structfield.go
similarity index 57%
rename from plugins/starrocks/utils.go
rename to utils/structfield.go
index 01415af1..4df76282 100644
--- a/plugins/starrocks/utils.go
+++ b/utils/structfield.go
@@ -14,22 +14,29 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
-package main
-import "strings"
+package utils
-func getDataType(dataType string) string {
- starrocksDatatype := dataType
- if strings.HasPrefix(dataType, "varchar") {
- starrocksDatatype = "string"
- } else if strings.HasPrefix(dataType, "datetime") {
- starrocksDatatype = "datetime"
- } else if strings.HasPrefix(dataType, "bigint") {
- starrocksDatatype = "bigint"
- } else if dataType == "longtext" || dataType == "text" || dataType ==
"longblob" {
- starrocksDatatype = "string"
- } else if dataType == "tinyint(1)" {
- starrocksDatatype = "boolean"
+import (
+ "reflect"
+)
+
+// WalkFiled get the field data by tag
+func WalkFields(t reflect.Type, filter func(field *reflect.StructField) bool)
(f []reflect.StructField) {
+ if t.Kind() == reflect.Ptr {
+ t = t.Elem()
+ }
+ for i := 0; i < t.NumField(); i++ {
+ field := t.Field(i)
+ if field.Type.Kind() == reflect.Struct {
+ f = append(f, WalkFields(field.Type, filter)...)
+ } else {
+ if filter == nil {
+ f = append(f, field)
+ } else if filter(&field) {
+ f = append(f, field)
+ }
+ }
}
- return starrocksDatatype
+ return f
}