This is an automated email from the ASF dual-hosted git repository.

miaoliyao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/shardingsphere-on-cloud.git


The following commit(s) were added to refs/heads/main by this push:
     new 58441dc  feat: Add convert AST to string method (#310)
58441dc is described below

commit 58441dc1bf008babce0786afbf8ce428d8ff8c0b
Author: Jack <[email protected]>
AuthorDate: Thu Apr 13 16:53:50 2023 +0800

    feat: Add convert AST to string method (#310)
    
    * feat: add AST ToString() method
    
    Signed-off-by: wangbo <[email protected]>
    
    * chore: add license
    
    Signed-off-by: wangbo <[email protected]>
    
    ---------
    
    Signed-off-by: wangbo <[email protected]>
    Co-authored-by: wangbo <[email protected]>
---
 shardingsphere-operator/pkg/distsql/ast/rdl_ast.go | 245 ++++++++++++++++++++-
 .../pkg/distsql/visitor/distsql_test.go            |  55 +++++
 .../pkg/distsql/visitor/rdl_visitor.go             |  15 +-
 .../pkg/distsql/visitor/visitor_suite_test.go      |  30 +++
 4 files changed, 336 insertions(+), 9 deletions(-)

diff --git a/shardingsphere-operator/pkg/distsql/ast/rdl_ast.go 
b/shardingsphere-operator/pkg/distsql/ast/rdl_ast.go
index 2491ed2..67547ec 100644
--- a/shardingsphere-operator/pkg/distsql/ast/rdl_ast.go
+++ b/shardingsphere-operator/pkg/distsql/ast/rdl_ast.go
@@ -17,15 +17,35 @@
 
 package ast
 
+import (
+       "fmt"
+       "strings"
+)
+
 // Define RDL AST
 type CreateEncryptRule struct {
-       Create                   string
-       Encrypt                  string
-       EncryptName              string
        IfNotExists              *IfNotExists
+       EncryptRuleDefinition    *EncryptRuleDefinition
        AllEncryptRuleDefinition []*EncryptRuleDefinition
 }
 
+func (createEncryptRule *CreateEncryptRule) ToString() string {
+       var ifNotExists string
+       var allEncryptRuleDefinitionList []string
+       if createEncryptRule.IfNotExists != nil {
+               ifNotExists = createEncryptRule.IfNotExists.ToString()
+       }
+
+       if createEncryptRule.AllEncryptRuleDefinition != nil {
+               for _, encryptRuleDefinition := range 
createEncryptRule.AllEncryptRuleDefinition {
+                       if encryptRuleDefinition != nil {
+                               allEncryptRuleDefinitionList = 
append(allEncryptRuleDefinitionList, encryptRuleDefinition.ToString())
+                       }
+               }
+       }
+       return fmt.Sprintf("CREATE ENCRYPT RULE%s %s;", ifNotExists, 
strings.Join(allEncryptRuleDefinitionList, ","))
+}
+
 type AlterEncryptRule struct {
        EncryptRuleDefinition []*EncryptRuleDefinition
 }
@@ -50,12 +70,48 @@ func (dropEncryptRule *DropEncryptRule) ToString() string {
 type EncryptRuleDefinition struct {
        TableName                  *CommonIdentifier
        ResourceDefinition         *ResourceDefinition
+       EncryptColumnDefinition    *EncryptColumnDefinition
        AllEncryptColumnDefinition []*EncryptColumnDefinition
        QueryWithCipherColumn      *QueryWithCipherColumn
 }
 
 func (encryptRuleDefinition *EncryptRuleDefinition) ToString() string {
-       return ""
+       var (
+               tableName                  string
+               resourceDefinition         string
+               queryWithCipherColumn      string
+               encryptColumnDefinition    string
+               allEncryptColumnDefinition []string
+       )
+
+       if encryptRuleDefinition.TableName != nil {
+               tableName = encryptRuleDefinition.TableName.ToString()
+       }
+
+       if encryptRuleDefinition.ResourceDefinition != nil {
+               resourceDefinition = 
encryptRuleDefinition.ResourceDefinition.ToString()
+       }
+
+       if encryptRuleDefinition.EncryptColumnDefinition != nil {
+               encryptColumnDefinition = 
encryptRuleDefinition.EncryptColumnDefinition.ToString()
+       }
+
+       if encryptRuleDefinition.AllEncryptColumnDefinition != nil {
+               for _, rd := range 
encryptRuleDefinition.AllEncryptColumnDefinition {
+                       allEncryptColumnDefinition = 
append(allEncryptColumnDefinition, rd.ToString())
+               }
+       }
+
+       if encryptRuleDefinition.QueryWithCipherColumn != nil {
+               queryWithCipherColumn = 
fmt.Sprintf(",QUERY_WITH_CIPHER_COLUMN=%s", 
encryptRuleDefinition.QueryWithCipherColumn.ToString())
+       }
+
+       return fmt.Sprintf("%s (%sCOLUMNS(%s%s)%s)",
+               tableName,
+               resourceDefinition,
+               encryptColumnDefinition,
+               strings.Join(allEncryptColumnDefinition, ","),
+               queryWithCipherColumn)
 }
 
 type IfNotExists struct {
@@ -63,13 +119,20 @@ type IfNotExists struct {
 }
 
 func (ifNotExists IfNotExists) ToString() string {
-       return ""
+       return ifNotExists.IfNotExists
 }
 
 type ResourceDefinition struct {
        ResourceName *CommonIdentifier
 }
 
+func (resourceDefinition *ResourceDefinition) ToString() string {
+       if resourceDefinition.ResourceName != nil {
+               return fmt.Sprintf("RESOURCE=%s", 
resourceDefinition.ResourceName.ToString())
+       }
+       return ""
+}
+
 type EncryptColumnDefinition struct {
        ColumnDefinition              *ColumnDefinition
        PlainColumnDefinition         *PlainColumnDefinition
@@ -82,82 +145,254 @@ type EncryptColumnDefinition struct {
        QueryWithCipherColumn         *QueryWithCipherColumn
 }
 
+func (encryptColumnDefinition *EncryptColumnDefinition) ToString() (sql 
string) {
+       var (
+               plainColumnDefinition         string
+               assistedQueryColumnDefinition string
+               likeQueryColumnDefinition     string
+               assistedQueryAlgorithm        string
+               likeQueryAlgorithm            string
+               queryWithCipherColumn         string
+       )
+
+       if encryptColumnDefinition.PlainColumnDefinition != nil {
+               sql += encryptColumnDefinition.PlainColumnDefinition.ToString()
+               plainColumnDefinition = fmt.Sprintf(",%s", 
encryptColumnDefinition.PlainColumnDefinition.ToString())
+       }
+
+       if encryptColumnDefinition.CipherColumnDefinition != nil {
+               sql += encryptColumnDefinition.CipherColumnDefinition.ToString()
+       }
+
+       if encryptColumnDefinition.AssistedQueryColumnDefinition != nil {
+               assistedQueryColumnDefinition = fmt.Sprintf(",%s", 
encryptColumnDefinition.AssistedQueryColumnDefinition.ToString())
+       }
+
+       if encryptColumnDefinition.LikeQueryAlgorithm != nil {
+               likeQueryColumnDefinition = fmt.Sprintf(",%s", 
encryptColumnDefinition.LikeQueryAlgorithm.ToString())
+       }
+
+       if encryptColumnDefinition.AssistedQueryAlgorithm != nil {
+               assistedQueryAlgorithm = fmt.Sprintf(",%s", 
encryptColumnDefinition.AssistedQueryAlgorithm.ToString())
+       }
+
+       if encryptColumnDefinition.LikeQueryAlgorithm != nil {
+               likeQueryAlgorithm = fmt.Sprintf(",%s", 
encryptColumnDefinition.LikeQueryAlgorithm.ToString())
+       }
+
+       if encryptColumnDefinition.QueryWithCipherColumn != nil {
+               queryWithCipherColumn = 
fmt.Sprintf(",QUERY_WITH_CIPHER_COLUMN=%s", 
encryptColumnDefinition.QueryWithCipherColumn.ToString())
+       }
+
+       return fmt.Sprintf("(%s%s,%s%s%s,%s%s%s%s)",
+               encryptColumnDefinition.ColumnDefinition.ToString(),
+               plainColumnDefinition,
+               encryptColumnDefinition.CipherColumnDefinition.ToString(),
+               assistedQueryColumnDefinition,
+               likeQueryColumnDefinition,
+               encryptColumnDefinition.EncryptAlgorithm.ToString(),
+               assistedQueryAlgorithm,
+               likeQueryAlgorithm,
+               queryWithCipherColumn)
+}
+
 type ColumnDefinition struct {
        ColumnName *CommonIdentifier
        DataType   *DataType
 }
 
+func (columnDefinition *ColumnDefinition) ToString() (sql string) {
+       var dataType string
+       if columnDefinition.DataType != nil {
+               dataType = fmt.Sprintf(",DATA_TYP=%s", 
columnDefinition.DataType.ToString())
+       }
+
+       sql = fmt.Sprintf("NAME=%s%s", columnDefinition.ColumnName.ToString(), 
dataType)
+
+       return
+}
+
 type PlainColumnDefinition struct {
        PlainColumnName *CommonIdentifier
        DataType        *DataType
 }
 
+func (plainColumnDefinition *PlainColumnDefinition) ToString() (sql string) {
+       if plainColumnDefinition.PlainColumnName != nil {
+               sql += fmt.Sprintf("PLAIN=%s", 
plainColumnDefinition.PlainColumnName.ToString())
+       }
+       if plainColumnDefinition.DataType != nil {
+               sql += plainColumnDefinition.DataType.ToString()
+       }
+       return
+}
+
 type CipherColumnDefinition struct {
        CipherColumnName *CommonIdentifier
        DataType         *DataType
 }
 
+func (cipherColumnDefinition *CipherColumnDefinition) ToString() string {
+       var dataType string
+       if cipherColumnDefinition.DataType != nil {
+               dataType = fmt.Sprintf(",CIPHER_DATA_TYPE=%s", dataType)
+       }
+       return fmt.Sprintf("CIPHER=%s%s", 
cipherColumnDefinition.CipherColumnName.ToString(), dataType)
+}
+
 type AssistedQueryColumnDefinition struct {
        AssistedQueryColumnName *CommonIdentifier
        DataType                *DataType
 }
 
+func (assistedQueryColumnDefinition *AssistedQueryColumnDefinition) ToString() 
string {
+       var dataType string
+       if assistedQueryColumnDefinition.DataType != nil {
+               dataType = fmt.Sprintf(",ASSISTED_QUERY_DATA_TYPE=%s", 
assistedQueryColumnDefinition.DataType.ToString())
+       }
+       return fmt.Sprintf("ASSISTED_QUERY_COLUMN=%s%s", 
assistedQueryColumnDefinition.AssistedQueryColumnName.ToString(), dataType)
+}
+
 type LikeQueryColumnDefinition struct {
        LikeQueryColumnName *CommonIdentifier
        DataType            *DataType
 }
 
+func (likeQueryColumnDefinition *LikeQueryColumnDefinition) ToString() string {
+       var dataType string
+       if likeQueryColumnDefinition.DataType != nil {
+               dataType = fmt.Sprintf("COMMA_ LIKE_QUERY_DATA_TYPE=%s", 
likeQueryColumnDefinition.DataType.ToString())
+       }
+       return fmt.Sprintf("LIKE_QUERY_COLUMN=%s%s", 
likeQueryColumnDefinition.LikeQueryColumnName.ToString(), dataType)
+}
+
 type EncryptAlgorithm struct {
        AlgorithmDefinition *AlgorithmDefinition
 }
 
+func (encryptAlgorithm *EncryptAlgorithm) ToString() string {
+       return fmt.Sprintf("ENCRYPT_ALGORITHM(%s)", 
encryptAlgorithm.AlgorithmDefinition.ToString())
+}
+
 type AssistedQueryAlgorithm struct {
        AlgorithmDefinition *AlgorithmDefinition
 }
 
+func (assistedQueryAlgorithm *AssistedQueryAlgorithm) ToString() string {
+       return assistedQueryAlgorithm.AlgorithmDefinition.ToString()
+}
+
 type AlgorithmDefinition struct {
        AlgorithmTypeName    *AlgorithmTypeName
        PropertiesDefinition *PropertiesDefinition
 }
 
+func (algorithmDefinition AlgorithmDefinition) ToString() string {
+       var propertiesDefinition string
+
+       if algorithmDefinition.PropertiesDefinition != nil {
+               propertiesDefinition = fmt.Sprintf(",%s", 
algorithmDefinition.PropertiesDefinition.ToString())
+       }
+
+       return fmt.Sprintf("TYPE(NAME=%s%s)", 
algorithmDefinition.AlgorithmTypeName.ToString(), propertiesDefinition)
+}
+
 type PropertiesDefinition struct {
        Properties *Properties
 }
 
+func (propertiesDefinition *PropertiesDefinition) ToString() string {
+       if propertiesDefinition.Properties != nil {
+               return fmt.Sprintf("PROPERTIES(%s)", 
propertiesDefinition.Properties.ToString())
+       }
+       return ""
+}
+
 type Properties struct {
        Properties []*Property
 }
 
+func (properties *Properties) ToString() (sql string) {
+       for _, property := range properties.Properties {
+               sql += property.ToString()
+       }
+       return
+}
+
 type LikeQueryAlgorithm struct {
        AlgorithmDefinition *AlgorithmDefinition
 }
 
+func (likeQueryAlgorithm *LikeQueryAlgorithm) ToString() (sql string) {
+       if likeQueryAlgorithm.AlgorithmDefinition != nil {
+               sql += likeQueryAlgorithm.ToString()
+       }
+       return
+}
+
 type QueryWithCipherColumn struct {
        QueryWithCipherColumn string
 }
 
+func (queryWithAlgorithm *QueryWithCipherColumn) ToString() string {
+       return queryWithAlgorithm.QueryWithCipherColumn
+}
+
 type CommonIdentifier struct {
        Identifier string
 }
 
+func (commonIdentifier *CommonIdentifier) ToString() string {
+       return commonIdentifier.Identifier
+}
+
 type Property struct {
        Key     string
        Literal *Literal
 }
 
+func (property *Property) ToString() (sql string) {
+       if property.Literal != nil {
+               sql = fmt.Sprintf("%s=%s", property.Key, 
property.Literal.ToString())
+       }
+       return
+}
+
 type Literal struct {
        Literal string
 }
 
+func (literal *Literal) ToString() string {
+       return literal.Literal
+}
+
 type BuildinAlgorithmTypeName struct {
        AlgorithmTypeName string
 }
 
+func (buildinAlgorithmTypeName *BuildinAlgorithmTypeName) ToString() string {
+       return buildinAlgorithmTypeName.AlgorithmTypeName
+}
+
 type DataType struct {
        String string
 }
 
+func (dataType *DataType) ToString() string {
+       return dataType.String
+}
+
 type AlgorithmTypeName struct {
        BuildinAlgorithmTypeName *BuildinAlgorithmTypeName
        String                   string
 }
+
+func (algorithmTypeName *AlgorithmTypeName) ToString() string {
+       switch {
+       case algorithmTypeName.BuildinAlgorithmTypeName != nil:
+               return algorithmTypeName.BuildinAlgorithmTypeName.ToString()
+       case algorithmTypeName.String != "":
+               return algorithmTypeName.String
+       }
+       return ""
+}
diff --git a/shardingsphere-operator/pkg/distsql/visitor/distsql_test.go 
b/shardingsphere-operator/pkg/distsql/visitor/distsql_test.go
new file mode 100644
index 0000000..0f5d683
--- /dev/null
+++ b/shardingsphere-operator/pkg/distsql/visitor/distsql_test.go
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package visitor
+
+import (
+       "github.com/antlr/antlr4/runtime/Go/antlr"
+       
"github.com/apache/shardingsphere-on-cloud/shardingsphere-operator/pkg/distsql/ast"
+       parser 
"github.com/apache/shardingsphere-on-cloud/shardingsphere-operator/pkg/distsql/visitor_parser/encrypt"
+       . "github.com/onsi/ginkgo/v2"
+       . "github.com/onsi/gomega"
+)
+
+var _ = Describe("Distsql", func() {
+       var (
+               encryptDistSQL = "CREATE ENCRYPT RULE t_encrypt 
(COLUMNS((NAME=user_id,PLAIN=user_plain,CIPHER=user_cipher,ENCRYPT_ALGORITHM(TYPE(NAME='AES',PROPERTIES('aes-key-value'='123456abc')))),(NAME=order_id,CIPHER=order_cipher,ENCRYPT_ALGORITHM(TYPE(NAME='MD5')))),QUERY_WITH_CIPHER_COLUMN=true);"
+               visitor        = Visitor{}
+               ast            = &ast.CreateEncryptRule{}
+       )
+
+       BeforeEach(func() {
+               inputStream := antlr.NewInputStream(encryptDistSQL)
+               lexer := parser.NewRDLStatementLexer(inputStream)
+               tokens := antlr.NewCommonTokenStream(lexer, 
antlr.TokenDefaultChannel)
+               distSQLParser := parser.NewRDLStatementParser(tokens)
+               createEncryptRule := distSQLParser.CreateEncryptRule()
+               ast = 
visitor.VisitCreateEncryptRule(createEncryptRule.(*parser.CreateEncryptRuleContext))
+       })
+
+       Context("parse distSQL to AST", func() {
+               It("should encrypt distSQL parse correctly", func() {
+                       
Expect(ast.AllEncryptRuleDefinition[0].TableName.Identifier).To(Equal("t_encrypt"))
+               })
+       })
+
+       Context("covert distSQL AST to string", func() {
+               It("should encrypt distsql parse correctly", func() {
+                       
Expect(ast.AllEncryptRuleDefinition[0].TableName.ToString()).To(Equal("t_encrypt"))
+               })
+       })
+})
diff --git a/shardingsphere-operator/pkg/distsql/visitor/rdl_visitor.go 
b/shardingsphere-operator/pkg/distsql/visitor/rdl_visitor.go
index 3adaeaa..36a4662 100644
--- a/shardingsphere-operator/pkg/distsql/visitor/rdl_visitor.go
+++ b/shardingsphere-operator/pkg/distsql/visitor/rdl_visitor.go
@@ -30,17 +30,20 @@ type Visitor struct {
 
 func (v *Visitor) VisitCreateEncryptRule(ctx *parser.CreateEncryptRuleContext) 
*ast.CreateEncryptRule {
        stmt := &ast.CreateEncryptRule{}
-       stmt.Create = ctx.CREATE().GetText()
-       stmt.Encrypt = ctx.ENCRYPT().GetText()
-       stmt.EncryptName = ctx.RULE().GetText()
 
        if ctx.IfNotExists() != nil {
                stmt.IfNotExists = 
v.VisitIfNotExists(ctx.IfNotExists().(*parser.IfNotExistsContext))
        }
 
+       if ctx.EncryptRuleDefinition(0) != nil {
+               stmt.EncryptRuleDefinition = 
v.VisitEncryptRuleDefinition(ctx.EncryptRuleDefinition(0).(*parser.EncryptRuleDefinitionContext))
+       }
+
        if ctx.AllEncryptRuleDefinition() != nil {
                for _, r := range ctx.AllEncryptRuleDefinition() {
-                       stmt.AllEncryptRuleDefinition = 
append(stmt.AllEncryptRuleDefinition, 
v.VisitEncryptRuleDefinition(r.(*parser.EncryptRuleDefinitionContext)))
+                       if r != nil {
+                               stmt.AllEncryptRuleDefinition = 
append(stmt.AllEncryptRuleDefinition, 
v.VisitEncryptRuleDefinition(r.(*parser.EncryptRuleDefinitionContext)))
+                       }
                }
        }
 
@@ -95,6 +98,10 @@ func (v *Visitor) VisitEncryptRuleDefinition(ctx 
*parser.EncryptRuleDefinitionCo
                stmt.ResourceDefinition = 
v.VisitResourceDefinition(ctx.ResourceDefinition().(*parser.ResourceDefinitionContext))
        }
 
+       // if ctx.EncryptColumnDefinition(0) != nil {
+       //      stmt.EncryptColumnDefinition = 
v.VisitEncryptColumnDefinition(ctx.EncryptColumnDefinition(0).(*parser.EncryptColumnDefinitionContext))
+       // }
+
        if ctx.AllEncryptColumnDefinition() != nil {
                for _, column := range ctx.AllEncryptColumnDefinition() {
                        stmt.AllEncryptColumnDefinition = 
append(stmt.AllEncryptColumnDefinition, 
v.VisitEncryptColumnDefinition(column.(*parser.EncryptColumnDefinitionContext)))
diff --git a/shardingsphere-operator/pkg/distsql/visitor/visitor_suite_test.go 
b/shardingsphere-operator/pkg/distsql/visitor/visitor_suite_test.go
new file mode 100644
index 0000000..8cd3276
--- /dev/null
+++ b/shardingsphere-operator/pkg/distsql/visitor/visitor_suite_test.go
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package visitor
+
+import (
+       "testing"
+
+       . "github.com/onsi/ginkgo/v2"
+       . "github.com/onsi/gomega"
+)
+
+func TestVisitor(t *testing.T) {
+       RegisterFailHandler(Fail)
+       RunSpecs(t, "Visitor Suite")
+}

Reply via email to