This is an automated email from the ASF dual-hosted git repository.

mappjzc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-devlake.git


The following commit(s) were added to refs/heads/main by this push:
     new 04bb9d42 refactor: reflect for get primary key (#2422)
04bb9d42 is described below

commit 04bb9d42018001db51983fddcb374b7bbe18fefa
Author: mappjzc <[email protected]>
AuthorDate: Thu Jul 7 16:22:51 2022 +0800

    refactor: reflect for get primary key (#2422)
    
    * refactor: reflect for get primary key
    
    Add GetPrimaryKeyNameFromDB to dal
    Add GetPrimaryKey to dal
    Add WalkFields to utils
    
    Nddtfjiang <[email protected]>
    
    * refactor: rename get primarykey
    
    Rename GetPrimaryKeyNameFromDB to GetPrimarykeyColumnNames.
    Rename GetPrimaryKey to GetPrimarykeyFields.
    
    Nddtfjiang <[email protected]>
---
 helpers/e2ehelper/data_flow_tester.go              | 75 ++++++++++------------
 helpers/e2ehelper/data_flow_tester_test.go         | 21 +++---
 impl/dalgorm/dalgorm.go                            | 70 ++++++++------------
 models/domainlayer/didgen/domain_id_generator.go   | 44 ++++---------
 plugins/core/dal/dal.go                            | 57 +++++++++++++++-
 plugins/helper/batch_save.go                       | 53 +++++----------
 plugins/helper/batch_save_divider_test.go          | 10 ++-
 plugins/helper/config_util.go                      | 45 ++++---------
 plugins/starrocks/tasks.go                         | 50 +++++++++++----
 plugins/starrocks/utils.go => utils/structfield.go | 37 ++++++-----
 10 files changed, 228 insertions(+), 234 deletions(-)

diff --git a/helpers/e2ehelper/data_flow_tester.go 
b/helpers/e2ehelper/data_flow_tester.go
index 28c33875..26cdca16 100644
--- a/helpers/e2ehelper/data_flow_tester.go
+++ b/helpers/e2ehelper/data_flow_tester.go
@@ -22,6 +22,13 @@ import (
        "database/sql"
        "encoding/json"
        "fmt"
+       "os"
+       "strconv"
+       "strings"
+       "sync"
+       "testing"
+       "time"
+
        "github.com/apache/incubator-devlake/config"
        "github.com/apache/incubator-devlake/helpers/pluginhelper"
        "github.com/apache/incubator-devlake/impl/dalgorm"
@@ -35,12 +42,6 @@ import (
        "github.com/stretchr/testify/assert"
        "gorm.io/gorm"
        "gorm.io/gorm/schema"
-       "os"
-       "strconv"
-       "strings"
-       "sync"
-       "testing"
-       "time"
 )
 
 // DataFlowTester provides a universal data integrity validation facility to 
help `Plugin` verifying records between
@@ -186,14 +187,7 @@ func (t *DataFlowTester) Subtask(subtaskMeta 
core.SubTaskMeta, taskData interfac
        }
 }
 
-func (t *DataFlowTester) getPkFields(dst schema.Tabler) []string {
-       return t.getFields(dst, func(column gorm.ColumnType) bool {
-               isPk, _ := column.PrimaryKey()
-               return isPk
-       })
-}
-
-func filterColumn(column gorm.ColumnType, opts TableOptions) bool {
+func filterColumn(column dal.ColumnMeta, opts TableOptions) bool {
        for _, ignore := range opts.IgnoreFields {
                if column.Name() == ignore {
                        return false
@@ -212,31 +206,21 @@ func filterColumn(column gorm.ColumnType, opts 
TableOptions) bool {
        return targetFound
 }
 
-func (t *DataFlowTester) getFields(dst schema.Tabler, filter func(column 
gorm.ColumnType) bool) []string {
-       columnTypes, err := t.Db.Migrator().ColumnTypes(dst)
-       var fields []string
-       if err != nil {
-               panic(err)
-       }
-       for _, columnType := range columnTypes {
-               if filter == nil || filter(columnType) {
-                       fields = append(fields, columnType.Name())
-               }
-       }
-       return fields
-}
-
 // CreateSnapshot reads rows from database and write them into .csv file.
 func (t *DataFlowTester) CreateSnapshot(dst schema.Tabler, opts TableOptions) {
        location, _ := time.LoadLocation(`UTC`)
-       pkFields := t.getPkFields(dst)
+
        targetFields := t.resolveTargetFields(dst, opts)
-       allFields := append(pkFields, targetFields...)
+       pkColumnNames, err := dal.GetPrimarykeyColumnNames(t.Dal, dst)
+       if err != nil {
+               panic(err)
+       }
+       allFields := append(pkColumnNames, targetFields...)
        allFields = utils.StringsUniq(allFields)
        dbCursor, err := t.Dal.Cursor(
                dal.Select(strings.Join(allFields, `,`)),
                dal.From(dst.TableName()),
-               dal.Orderby(strings.Join(pkFields, `,`)),
+               dal.Orderby(strings.Join(pkColumnNames, `,`)),
        )
        if err != nil {
                panic(fmt.Errorf("unable to run select query on table %s: %v", 
dst.TableName(), err))
@@ -369,9 +353,13 @@ func (t *DataFlowTester) resolveTargetFields(dst 
schema.Tabler, opts TableOption
        }
        var targetFields []string
        if len(opts.TargetFields) == 0 || len(opts.IgnoreFields) > 0 {
-               targetFields = append(targetFields, t.getFields(dst, 
func(column gorm.ColumnType) bool {
-                       return filterColumn(column, opts)
-               })...)
+               names, err := dal.GetColumnNames(t.Dal, dst, func(cm 
dal.ColumnMeta) bool {
+                       return filterColumn(cm, opts)
+               })
+               if err != nil {
+                       panic(err)
+               }
+               targetFields = append(targetFields, names...)
        } else {
                targetFields = opts.TargetFields
        }
@@ -388,17 +376,22 @@ func (t *DataFlowTester) VerifyTableWithOptions(dst 
schema.Tabler, opts TableOpt
                t.CreateSnapshot(dst, opts)
                return
        }
+
        targetFields := t.resolveTargetFields(dst, opts)
-       pkFields := t.getPkFields(dst)
+       pkColumns, err := dal.GetPrimarykeyColumns(t.Dal, dst)
+       if err != nil {
+               panic(err)
+       }
+
        csvIter := pluginhelper.NewCsvFileIterator(opts.CSVRelPath)
        defer csvIter.Close()
        expectedTotal := int64(0)
        csvMap := map[string]map[string]interface{}{}
        for csvIter.HasNext() {
                expected := csvIter.Fetch()
-               pkValues := make([]string, 0, len(pkFields))
-               for _, pkf := range pkFields {
-                       pkValues = append(pkValues, expected[pkf].(string))
+               pkValues := make([]string, 0, len(pkColumns))
+               for _, pkc := range pkColumns {
+                       pkValues = append(pkValues, 
expected[pkc.Name()].(string))
                }
                pkValueStr := strings.Join(pkValues, `-`)
                _, ok := csvMap[pkValueStr]
@@ -414,9 +407,9 @@ func (t *DataFlowTester) VerifyTableWithOptions(dst 
schema.Tabler, opts TableOpt
                panic(err)
        }
        for _, actual := range *dbRows {
-               pkValues := make([]string, 0, len(pkFields))
-               for _, pkf := range pkFields {
-                       pkValues = append(pkValues, formatDbValue(actual[pkf]))
+               pkValues := make([]string, 0, len(pkColumns))
+               for _, pkc := range pkColumns {
+                       pkValues = append(pkValues, 
formatDbValue(actual[pkc.Name()]))
                }
                expected, ok := csvMap[strings.Join(pkValues, `-`)]
                assert.True(t.T, ok, fmt.Sprintf(`%s not found (with params 
from csv %s)`, dst.TableName(), pkValues))
diff --git a/helpers/e2ehelper/data_flow_tester_test.go 
b/helpers/e2ehelper/data_flow_tester_test.go
index e98d792f..c13d9945 100644
--- a/helpers/e2ehelper/data_flow_tester_test.go
+++ b/helpers/e2ehelper/data_flow_tester_test.go
@@ -18,11 +18,12 @@ limitations under the License.
 package e2ehelper
 
 import (
+       "testing"
+
        "github.com/apache/incubator-devlake/models/common"
+       "github.com/apache/incubator-devlake/plugins/core/dal"
        gitlabModels "github.com/apache/incubator-devlake/plugins/gitlab/models"
        "github.com/stretchr/testify/assert"
-       "gorm.io/gorm"
-       "testing"
 
        "github.com/apache/incubator-devlake/plugins/core"
        "github.com/apache/incubator-devlake/plugins/gitlab/tasks"
@@ -87,11 +88,10 @@ func TestGetTableMetaData(t *testing.T) {
        var meta core.PluginMeta
        dataflowTester := NewDataFlowTester(t, "test_dataflow", meta)
        dataflowTester.FlushTabler(&TestModel{})
-       t.Run("get_fields", func(t *testing.T) {
-               fields := dataflowTester.getFields(&TestModel{}, func(column 
gorm.ColumnType) bool {
-                       return true
-               })
-               assert.Equal(t, 9, len(fields))
+       t.Run("dal_get_columns", func(t *testing.T) {
+               names, err := dal.GetColumnNames(dataflowTester.Dal, 
&TestModel{}, nil)
+               assert.Equal(t, err, nil)
+               assert.Equal(t, 9, len(names))
                for _, e := range []string{
                        "connection_id",
                        "issue_id",
@@ -103,7 +103,7 @@ func TestGetTableMetaData(t *testing.T) {
                        "_raw_data_id",
                        "_raw_data_remark",
                } {
-                       assert.Contains(t, fields, e)
+                       assert.Contains(t, names, e)
                }
        })
        t.Run("extract_columns", func(t *testing.T) {
@@ -118,8 +118,9 @@ func TestGetTableMetaData(t *testing.T) {
                        assert.Contains(t, columns, e)
                }
        })
-       t.Run("get_pk_fields", func(t *testing.T) {
-               fields := dataflowTester.getPkFields(&TestModel{})
+       t.Run("dal_get_pk_column_names", func(t *testing.T) {
+               fields, err := dal.GetPrimarykeyColumnNames(dataflowTester.Dal, 
&TestModel{})
+               assert.Equal(t, err, nil)
                assert.Equal(t, 3, len(fields))
                for _, e := range []string{
                        "connection_id",
diff --git a/impl/dalgorm/dalgorm.go b/impl/dalgorm/dalgorm.go
index 773c2e06..d5ead148 100644
--- a/impl/dalgorm/dalgorm.go
+++ b/impl/dalgorm/dalgorm.go
@@ -19,27 +19,20 @@ package dalgorm
 
 import (
        "database/sql"
-       "fmt"
+       "reflect"
        "strings"
 
        "github.com/apache/incubator-devlake/plugins/core/dal"
+       "github.com/apache/incubator-devlake/utils"
        "gorm.io/gorm"
        "gorm.io/gorm/clause"
+       "gorm.io/gorm/schema"
 )
 
 type Dalgorm struct {
        db *gorm.DB
 }
 
-// To accommodate gorm
-//type stubTable struct {
-//name string
-//}
-
-//func (s *stubTable) TableName() string {
-//return s.name
-//}
-
 func buildTx(tx *gorm.DB, clauses []dal.Clause) *gorm.DB {
        for _, c := range clauses {
                t := c.Type
@@ -151,6 +144,28 @@ func (d *Dalgorm) Delete(entity interface{}, clauses 
...dal.Clause) error {
        return buildTx(d.db, clauses).Delete(entity).Error
 }
 
+func (d *Dalgorm) GetColumns(dst schema.Tabler, filter func(columnMeta 
dal.ColumnMeta) bool) (cms []dal.ColumnMeta, err error) {
+       columnTypes, err := d.db.Migrator().ColumnTypes(dst.TableName())
+       if err != nil {
+               return
+       }
+       for _, columnType := range columnTypes {
+               if filter == nil {
+                       cms = append(cms, columnType)
+               } else if filter(columnType) {
+                       cms = append(cms, columnType)
+               }
+       }
+       return
+}
+
+// GetPrimaryKey get the PrimaryKey from `gorm` tag
+func (d *Dalgorm) GetPrimarykeyFields(t reflect.Type) []reflect.StructField {
+       return utils.WalkFields(t, func(field *reflect.StructField) bool {
+               return strings.Contains(strings.ToLower(field.Tag.Get("gorm")), 
"primarykey")
+       })
+}
+
 // AllTables returns all tables in the database
 func (d *Dalgorm) AllTables() ([]string, error) {
        var tableSql string
@@ -173,41 +188,6 @@ func (d *Dalgorm) AllTables() ([]string, error) {
        return filteredTables, nil
 }
 
-// GetTableColumns returns table columns in database
-func (d *Dalgorm) GetTableColumns(table string) (map[string]string, error) {
-       var columnSql string
-       ret := make(map[string]string)
-       if d.db.Dialector.Name() == "mysql" {
-               type MySQLColumn struct {
-                       Field string
-                       Type  string
-               }
-               var result []MySQLColumn
-               columnSql = fmt.Sprintf("show columns from %s", table)
-               err := d.db.Raw(columnSql).Scan(&result).Error
-               if err != nil {
-                       return nil, err
-               }
-               for _, item := range result {
-                       ret[item.Field] = item.Type
-               }
-       } else {
-               columnSql = fmt.Sprintf("select column_name,data_type from 
information_schema.COLUMNS where TABLE_NAME='%s' and TABLE_SCHEMA='public'", 
table)
-               type PostgresColumn struct {
-                       ColumnName string `gorm:"column_name"`
-                       DataType   string `gorm:"data_type"`
-               }
-               var result []PostgresColumn
-               err := d.db.Raw(columnSql).Scan(&result).Error
-               if err != nil {
-                       return nil, err
-               }
-               for _, item := range result {
-                       ret[item.ColumnName] = item.DataType
-               }
-       }
-       return ret, nil
-}
 func NewDalgorm(db *gorm.DB) *Dalgorm {
        return &Dalgorm{db}
 }
diff --git a/models/domainlayer/didgen/domain_id_generator.go 
b/models/domainlayer/didgen/domain_id_generator.go
index 06a81f3a..09087a58 100644
--- a/models/domainlayer/didgen/domain_id_generator.go
+++ b/models/domainlayer/didgen/domain_id_generator.go
@@ -20,15 +20,14 @@ package didgen
 import (
        "fmt"
        "reflect"
-       "strings"
 
+       "github.com/apache/incubator-devlake/impl/dalgorm"
        "github.com/apache/incubator-devlake/plugins/core"
 )
 
 type DomainIdGenerator struct {
-       prefix  string
-       pkNames []string
-       pkTypes []reflect.Type
+       prefix string
+       pk     []reflect.StructField
 }
 
 type WildCard string
@@ -37,23 +36,6 @@ const WILDCARD WildCard = "%"
 
 var wildcardType = reflect.TypeOf(WILDCARD)
 
-func walkFields(t reflect.Type, pkNames *[]string, pkTypes *[]reflect.Type) {
-       for i := 0; i < t.NumField(); i++ {
-               field := t.Field(i)
-               if field.Type.Kind() == reflect.Struct {
-                       walkFields(field.Type, pkNames, pkTypes)
-               } else {
-                       gormTag := field.Tag.Get("gorm")
-
-                       // TODO: regex?
-                       if gormTag != "" && 
strings.Contains(strings.ToLower(gormTag), "primarykey") {
-                               *pkNames = append(*pkNames, field.Name)
-                               *pkTypes = append(*pkTypes, field.Type)
-                       }
-               }
-       }
-}
-
 func NewDomainIdGenerator(entityPtr interface{}) *DomainIdGenerator {
        v := reflect.ValueOf(entityPtr)
        if v.Kind() != reflect.Ptr {
@@ -69,20 +51,16 @@ func NewDomainIdGenerator(entityPtr interface{}) 
*DomainIdGenerator {
        // find out entity type name
        structName := t.Name()
 
-       // find out all primkary keys and their types
-       pkNames := make([]string, 0, 1)
-       pkTypes := make([]reflect.Type, 0, 1)
-
-       walkFields(t, &pkNames, &pkTypes)
+       dal := &dalgorm.Dalgorm{}
+       pk := dal.GetPrimarykeyFields(t)
 
-       if len(pkNames) == 0 {
+       if len(pk) == 0 {
                panic(fmt.Errorf("no primary key found for %s:%s", pluginName, 
structName))
        }
 
        return &DomainIdGenerator{
-               prefix:  fmt.Sprintf("%s:%s", pluginName, structName),
-               pkNames: pkNames,
-               pkTypes: pkTypes,
+               prefix: fmt.Sprintf("%s:%s", pluginName, structName),
+               pk:     pk,
        }
 }
 
@@ -95,11 +73,11 @@ func (g *DomainIdGenerator) Generate(pkValues 
...interface{}) string {
                pkValueType := reflect.TypeOf(pkValue)
                if pkValueType == wildcardType {
                        break
-               } else if pkValueType != g.pkTypes[i] {
+               } else if pkValueType != g.pk[i].Type {
                        panic(fmt.Errorf("primary key type does not match: %s 
is %s type, and it should be %s type",
-                               g.pkNames[i],
+                               g.pk[i].Name,
                                pkValueType.Name(),
-                               g.pkTypes[i].Name(),
+                               g.pk[i].Type.Name(),
                        ))
                }
        }
diff --git a/plugins/core/dal/dal.go b/plugins/core/dal/dal.go
index 85d74c3b..4cc14dab 100644
--- a/plugins/core/dal/dal.go
+++ b/plugins/core/dal/dal.go
@@ -19,6 +19,9 @@ package dal
 
 import (
        "database/sql"
+       "reflect"
+
+       "gorm.io/gorm/schema"
 )
 
 type Clause struct {
@@ -26,6 +29,22 @@ type Clause struct {
        Data interface{}
 }
 
+// ColumnType column type interface
+type ColumnMeta interface {
+       Name() string
+       DatabaseTypeName() string                 // varchar
+       ColumnType() (columnType string, ok bool) // varchar(64)
+       PrimaryKey() (isPrimaryKey bool, ok bool)
+       AutoIncrement() (isAutoIncrement bool, ok bool)
+       Length() (length int64, ok bool)
+       DecimalSize() (precision int64, scale int64, ok bool)
+       Nullable() (nullable bool, ok bool)
+       Unique() (unique bool, ok bool)
+       ScanType() reflect.Type
+       Comment() (value string, ok bool)
+       DefaultValue() (value string, ok bool)
+}
+
 // Dal aims to facilitate an isolation between DBS and our System by defining 
a set of operations should a DBS provide
 type Dal interface {
        // AutoMigrate runs auto migration for given entity
@@ -58,8 +77,42 @@ type Dal interface {
        Delete(entity interface{}, clauses ...Clause) error
        // AllTables returns all tables in database
        AllTables() ([]string, error)
-       // GetTableColumns returns table columns in database
-       GetTableColumns(table string) (map[string]string, error)
+       // GetColumns returns table columns in database
+       GetColumns(dst schema.Tabler, filter func(columnMeta ColumnMeta) bool) 
(cms []ColumnMeta, err error)
+       // GetPrimarykeyFields get the PrimaryKey from `gorm` tag
+       GetPrimarykeyFields(t reflect.Type) []reflect.StructField
+}
+
+// GetPrimarykeyColumnNames returns table Column Names in database
+func GetColumnNames(d Dal, dst schema.Tabler, filter func(columnMeta 
ColumnMeta) bool) (names []string, err error) {
+       columns, err := d.GetColumns(dst, filter)
+       if err != nil {
+               return
+       }
+       for _, pkColumn := range columns {
+               names = append(names, pkColumn.Name())
+       }
+       return
+}
+
+// GetPrimarykeyColumns get returns PrimaryKey table Meta in database
+func GetPrimarykeyColumns(d Dal, dst schema.Tabler) ([]ColumnMeta, error) {
+       return d.GetColumns(dst, func(columnMeta ColumnMeta) bool {
+               isPrimaryKey, ok := columnMeta.PrimaryKey()
+               return isPrimaryKey && ok
+       })
+}
+
+// GetPrimarykeyColumnNames get returns PrimaryKey Column Names in database
+func GetPrimarykeyColumnNames(d Dal, dst schema.Tabler) (names []string, err 
error) {
+       pkColumns, err := GetPrimarykeyColumns(d, dst)
+       if err != nil {
+               return
+       }
+       for _, pkColumn := range pkColumns {
+               names = append(names, pkColumn.Name())
+       }
+       return
 }
 
 type DalClause struct {
diff --git a/plugins/helper/batch_save.go b/plugins/helper/batch_save.go
index 28643fee..84c021d5 100644
--- a/plugins/helper/batch_save.go
+++ b/plugins/helper/batch_save.go
@@ -38,6 +38,7 @@ type BatchSave struct {
        current    int
        size       int
        valueIndex map[string]int
+       primaryKey []reflect.StructField
 }
 
 const BATCH_SAVE_UPDATE_ONLY = 0
@@ -47,18 +48,23 @@ func NewBatchSave(basicRes core.BasicRes, slotType 
reflect.Type, size int) (*Bat
        if slotType.Kind() != reflect.Ptr {
                return nil, fmt.Errorf("slotType must be a pointer")
        }
-       if !hasPrimaryKey(slotType) {
+       dal := basicRes.GetDal()
+       primaryKey := dal.GetPrimarykeyFields(slotType)
+       // check if it have primaryKey
+       if len(primaryKey) == 0 {
                return nil, fmt.Errorf("%s no primary key", slotType.String())
        }
+
        log := basicRes.GetLogger().Nested(slotType.String())
        return &BatchSave{
                basicRes:   basicRes,
                log:        log,
-               db:         basicRes.GetDal(),
+               db:         dal,
                slotType:   slotType,
                slots:      reflect.MakeSlice(reflect.SliceOf(slotType), size, 
size),
                size:       size,
                valueIndex: make(map[string]int),
+               primaryKey: primaryKey,
        }, nil
 }
 
@@ -72,7 +78,7 @@ func (c *BatchSave) Add(slot interface{}) error {
                return fmt.Errorf("slot is not a pointer")
        }
        // deduplication
-       key := getPrimaryKeyValue(slot)
+       key := getKeyValue(slot, c.primaryKey)
 
        if key != "" {
                if index, ok := c.valueIndex[key]; !ok {
@@ -113,46 +119,17 @@ func (c *BatchSave) Close() error {
        return nil
 }
 
-func isPrimaryKey(f reflect.StructField) bool {
-       tag := strings.TrimSpace(f.Tag.Get("gorm"))
-       return strings.HasPrefix(strings.ToLower(tag), "primarykey")
-}
-
-func hasPrimaryKey(ifv reflect.Type) bool {
-       if ifv.Kind() == reflect.Ptr {
-               ifv = ifv.Elem()
-       }
-       for i := 0; i < ifv.NumField(); i++ {
-               v := ifv.Field(i)
-               if ok := isPrimaryKey(v); ok {
-                       return true
-               } else if v.Type.Kind() == reflect.Struct {
-                       if ok := hasPrimaryKey(v.Type); ok {
-                               return true
-                       }
-               }
-       }
-       return false
-}
-
-func getPrimaryKeyValue(iface interface{}) string {
+func getKeyValue(iface interface{}, primaryKey []reflect.StructField) string {
        var ss []string
        ifv := reflect.ValueOf(iface)
        if ifv.Kind() == reflect.Ptr {
                ifv = ifv.Elem()
        }
-       for i := 0; i < ifv.NumField(); i++ {
-               v := ifv.Field(i)
-               if isPrimaryKey(ifv.Type().Field(i)) {
-                       s := fmt.Sprintf("%v", v.Interface())
-                       if s != "" {
-                               ss = append(ss, s)
-                       }
-               } else if v.Kind() == reflect.Struct {
-                       s := getPrimaryKeyValue(v.Interface())
-                       if s != "" {
-                               ss = append(ss, s)
-                       }
+       for _, key := range primaryKey {
+               v := ifv.FieldByName(key.Name)
+               s := fmt.Sprintf("%v", v.Interface())
+               if s != "" {
+                       ss = append(ss, s)
                }
        }
        return strings.Join(ss, ":")
diff --git a/plugins/helper/batch_save_divider_test.go 
b/plugins/helper/batch_save_divider_test.go
index ae933a05..cabef8de 100644
--- a/plugins/helper/batch_save_divider_test.go
+++ b/plugins/helper/batch_save_divider_test.go
@@ -54,13 +54,19 @@ func TestBatchSaveDivider(t *testing.T) {
 
        mockLog := unithelper.DummyLogger()
        mockRes := new(mocks.BasicRes)
+
        mockRes.On("GetDal").Return(mockDal)
        mockRes.On("GetLogger").Return(mockLog)
 
-       divider := NewBatchSaveDivider(mockRes, 10, "", "")
-
        // we expect total 2 deletion calls after all code got carried out
        mockDal.On("Delete", mock.Anything, mock.Anything).Return(nil).Twice()
+       mockDal.On("GetPrimarykeyFields", mock.Anything).Return(
+               []reflect.StructField{
+                       {Name: "ID", Type: reflect.TypeOf("")},
+               },
+       )
+
+       divider := NewBatchSaveDivider(mockRes, 10, "", "")
 
        // for same type should return the same BatchSave
        jiraIssue1, err := divider.ForType(reflect.TypeOf(&MockJirIssueBsd{}))
diff --git a/plugins/helper/config_util.go b/plugins/helper/config_util.go
index 65a51a3c..ee51d007 100644
--- a/plugins/helper/config_util.go
+++ b/plugins/helper/config_util.go
@@ -21,6 +21,7 @@ import (
        "fmt"
        "reflect"
 
+       "github.com/apache/incubator-devlake/utils"
        "github.com/go-playground/validator/v10"
        "github.com/mitchellh/mapstructure"
        "github.com/spf13/viper"
@@ -50,16 +51,11 @@ func DecodeStruct(output *viper.Viper, input interface{}, 
data map[string]interf
        if vf.Kind() != reflect.Ptr {
                return fmt.Errorf("input %v is not a pointer", input)
        }
-       tf := reflect.Indirect(vf).Type()
-       fieldTags := make([]string, 0)
-       fieldNames := make([]string, 0)
-       fieldTypes := make([]reflect.Type, 0)
-       walkFields(tf, &fieldNames, &fieldTypes, &fieldTags, tag)
-       length := len(fieldNames)
-       for i := 0; i < length; i++ {
-               fieldName := fieldNames[i]
-               fieldType := fieldTypes[i]
-               fieldTag := fieldTags[i]
+
+       for _, f := range utils.WalkFields(reflect.Indirect(vf).Type(), nil) {
+               fieldName := f.Name
+               fieldType := f.Type
+               fieldTag := f.Tag.Get(tag)
 
                // Check if the first letter is uppercase (indicates a public 
element, accessible)
                ascii := rune(fieldName[0])
@@ -114,16 +110,11 @@ func EncodeStruct(input *viper.Viper, output interface{}, 
tag string) error {
        if vf.Kind() != reflect.Ptr {
                return fmt.Errorf("output %v is not a pointer", output)
        }
-       tf := reflect.Indirect(vf).Type()
-       fieldTags := make([]string, 0)
-       fieldNames := make([]string, 0)
-       fieldTypes := make([]reflect.Type, 0)
-       walkFields(tf, &fieldNames, &fieldTypes, &fieldTags, tag)
-       length := len(fieldNames)
-       for i := 0; i < length; i++ {
-               fieldName := fieldNames[i]
-               fieldType := fieldTypes[i]
-               fieldTag := fieldTags[i]
+
+       for _, f := range utils.WalkFields(reflect.Indirect(vf).Type(), nil) {
+               fieldName := f.Name
+               fieldType := f.Type
+               fieldTag := f.Tag.Get(tag)
 
                // Check if the first letter is uppercase (indicates a public 
element, accessible)
                ascii := rune(fieldName[0])
@@ -186,17 +177,3 @@ func EncodeStruct(input *viper.Viper, output interface{}, 
tag string) error {
        }
        return nil
 }
-
-func walkFields(t reflect.Type, fieldNames *[]string, fieldTypes 
*[]reflect.Type, fieldTags *[]string, tag string) {
-       for i := 0; i < t.NumField(); i++ {
-               field := t.Field(i)
-               if field.Type.Kind() == reflect.Struct {
-                       walkFields(field.Type, fieldNames, fieldTypes, 
fieldTags, tag)
-               } else {
-                       fieldTag := field.Tag.Get(tag)
-                       *fieldNames = append(*fieldNames, field.Name)
-                       *fieldTypes = append(*fieldTypes, field.Type)
-                       *fieldTags = append(*fieldTags, fieldTag)
-               }
-       }
-}
diff --git a/plugins/starrocks/tasks.go b/plugins/starrocks/tasks.go
index b14598b1..be98a7c6 100644
--- a/plugins/starrocks/tasks.go
+++ b/plugins/starrocks/tasks.go
@@ -30,6 +30,14 @@ import (
        "github.com/apache/incubator-devlake/plugins/core/dal"
 )
 
+type Table struct {
+       name string
+}
+
+func (t *Table) TableName() string {
+       return t.name
+}
+
 func LoadData(c core.SubTaskContext) error {
        config := c.GetData().(*StarRocksConfig)
        db := c.GetDal()
@@ -76,33 +84,47 @@ func LoadData(c core.SubTaskContext) error {
        return nil
 }
 func createTable(starrocks *sql.DB, db dal.Dal, starrocksTable string, table 
string, c core.SubTaskContext, extra string) error {
-       columnMap, err := db.GetTableColumns(table)
+
+       columeMetas, err := db.GetColumns(&Table{name: table}, nil)
        if err != nil {
                return err
        }
-       var pk string
-       if _, ok := columnMap["id"]; ok {
-               pk = "id"
-       } else {
-               for k := range columnMap {
-                       pk = k
-                       break
-               }
-       }
+       var pks string
        var columns []string
-       for field, dataType := range columnMap {
-               starrocksDatatype := getDataType(dataType)
-               column := fmt.Sprintf("%s %s", field, starrocksDatatype)
+       firstcm := ""
+       for _, cm := range columeMetas {
+               name := cm.Name()
+               starrocksDatatype, ok := cm.ColumnType()
+               if !ok {
+                       return fmt.Errorf("Get [%s] ColumeType Failed", name)
+               }
+               column := fmt.Sprintf("%s %s", name, starrocksDatatype)
                columns = append(columns, column)
+               isPrimaryKey, ok := cm.PrimaryKey()
+               if isPrimaryKey && ok {
+                       if pks != "" {
+                               pks += ","
+                       }
+                       pks += name
+               }
+               if firstcm == "" {
+                       firstcm = name
+               }
        }
+
+       if pks == "" {
+               pks = firstcm
+       }
+
        if extra == "" {
-               extra = fmt.Sprintf(`engine=olap distributed by hash(%s) 
properties("replication_num" = "1")`, pk)
+               extra = fmt.Sprintf(`engine=olap distributed by hash(%s) 
properties("replication_num" = "1")`, pks)
        }
        tableSql := fmt.Sprintf(`create table if not exists %s ( %s ) %s`, 
starrocksTable, strings.Join(columns, ","), extra)
        c.GetLogger().Info(tableSql)
        _, err = starrocks.Exec(tableSql)
        return err
 }
+
 func loadData(starrocks *sql.DB, c core.SubTaskContext, starrocksTable string, 
table string, db dal.Dal, config *StarRocksConfig) error {
        offset := 0
        starrocksTmpTable := starrocksTable + "_tmp"
diff --git a/plugins/starrocks/utils.go b/utils/structfield.go
similarity index 57%
rename from plugins/starrocks/utils.go
rename to utils/structfield.go
index 01415af1..4df76282 100644
--- a/plugins/starrocks/utils.go
+++ b/utils/structfield.go
@@ -14,22 +14,29 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 
express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
-package main
 
-import "strings"
+package utils
 
-func getDataType(dataType string) string {
-       starrocksDatatype := dataType
-       if strings.HasPrefix(dataType, "varchar") {
-               starrocksDatatype = "string"
-       } else if strings.HasPrefix(dataType, "datetime") {
-               starrocksDatatype = "datetime"
-       } else if strings.HasPrefix(dataType, "bigint") {
-               starrocksDatatype = "bigint"
-       } else if dataType == "longtext" || dataType == "text" || dataType == 
"longblob" {
-               starrocksDatatype = "string"
-       } else if dataType == "tinyint(1)" {
-               starrocksDatatype = "boolean"
+import (
+       "reflect"
+)
+
+// WalkFiled get the field data by tag
+func WalkFields(t reflect.Type, filter func(field *reflect.StructField) bool) 
(f []reflect.StructField) {
+       if t.Kind() == reflect.Ptr {
+               t = t.Elem()
+       }
+       for i := 0; i < t.NumField(); i++ {
+               field := t.Field(i)
+               if field.Type.Kind() == reflect.Struct {
+                       f = append(f, WalkFields(field.Type, filter)...)
+               } else {
+                       if filter == nil {
+                               f = append(f, field)
+                       } else if filter(&field) {
+                               f = append(f, field)
+                       }
+               }
        }
-       return starrocksDatatype
+       return f
 }

Reply via email to