You are viewing a plain text version of this content. The canonical link for it is here.
Posted to notifications@shardingsphere.apache.org by mi...@apache.org on 2023/04/13 08:53:55 UTC

[shardingsphere-on-cloud] branch main updated: feat: Add convert AST to string method (#310)

This is an automated email from the ASF dual-hosted git repository.

miaoliyao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/shardingsphere-on-cloud.git


The following commit(s) were added to refs/heads/main by this push:
     new 58441dc  feat: Add convert AST to string method (#310)
58441dc is described below

commit 58441dc1bf008babce0786afbf8ce428d8ff8c0b
Author: Jack <87...@qq.com>
AuthorDate: Thu Apr 13 16:53:50 2023 +0800

    feat: Add convert AST to string method (#310)
    
    * feat: add AST ToString() method
    
    Signed-off-by: wangbo <wa...@sphere-ex.com>
    
    * chore: add license
    
    Signed-off-by: wangbo <wa...@sphere-ex.com>
    
    ---------
    
    Signed-off-by: wangbo <wa...@sphere-ex.com>
    Co-authored-by: wangbo <wa...@sphere-ex.com>
---
 shardingsphere-operator/pkg/distsql/ast/rdl_ast.go | 245 ++++++++++++++++++++-
 .../pkg/distsql/visitor/distsql_test.go            |  55 +++++
 .../pkg/distsql/visitor/rdl_visitor.go             |  15 +-
 .../pkg/distsql/visitor/visitor_suite_test.go      |  30 +++
 4 files changed, 336 insertions(+), 9 deletions(-)

diff --git a/shardingsphere-operator/pkg/distsql/ast/rdl_ast.go b/shardingsphere-operator/pkg/distsql/ast/rdl_ast.go
index 2491ed2..67547ec 100644
--- a/shardingsphere-operator/pkg/distsql/ast/rdl_ast.go
+++ b/shardingsphere-operator/pkg/distsql/ast/rdl_ast.go
@@ -17,15 +17,35 @@
 
 package ast
 
+import (
+	"fmt"
+	"strings"
+)
+
 // Define RDL AST
 type CreateEncryptRule struct {
-	Create                   string
-	Encrypt                  string
-	EncryptName              string
 	IfNotExists              *IfNotExists
+	EncryptRuleDefinition    *EncryptRuleDefinition
 	AllEncryptRuleDefinition []*EncryptRuleDefinition
 }
 
+func (createEncryptRule *CreateEncryptRule) ToString() string {
+	var ifNotExists string
+	var allEncryptRuleDefinitionList []string
+	if createEncryptRule.IfNotExists != nil {
+		ifNotExists = createEncryptRule.IfNotExists.ToString()
+	}
+
+	if createEncryptRule.AllEncryptRuleDefinition != nil {
+		for _, encryptRuleDefinition := range createEncryptRule.AllEncryptRuleDefinition {
+			if encryptRuleDefinition != nil {
+				allEncryptRuleDefinitionList = append(allEncryptRuleDefinitionList, encryptRuleDefinition.ToString())
+			}
+		}
+	}
+	return fmt.Sprintf("CREATE ENCRYPT RULE%s %s;", ifNotExists, strings.Join(allEncryptRuleDefinitionList, ","))
+}
+
 type AlterEncryptRule struct {
 	EncryptRuleDefinition []*EncryptRuleDefinition
 }
@@ -50,12 +70,48 @@ func (dropEncryptRule *DropEncryptRule) ToString() string {
 type EncryptRuleDefinition struct {
 	TableName                  *CommonIdentifier
 	ResourceDefinition         *ResourceDefinition
+	EncryptColumnDefinition    *EncryptColumnDefinition
 	AllEncryptColumnDefinition []*EncryptColumnDefinition
 	QueryWithCipherColumn      *QueryWithCipherColumn
 }
 
 func (encryptRuleDefinition *EncryptRuleDefinition) ToString() string {
-	return ""
+	var (
+		tableName                  string
+		resourceDefinition         string
+		queryWithCipherColumn      string
+		encryptColumnDefinition    string
+		allEncryptColumnDefinition []string
+	)
+
+	if encryptRuleDefinition.TableName != nil {
+		tableName = encryptRuleDefinition.TableName.ToString()
+	}
+
+	if encryptRuleDefinition.ResourceDefinition != nil {
+		resourceDefinition = encryptRuleDefinition.ResourceDefinition.ToString()
+	}
+
+	if encryptRuleDefinition.EncryptColumnDefinition != nil {
+		encryptColumnDefinition = encryptRuleDefinition.EncryptColumnDefinition.ToString()
+	}
+
+	if encryptRuleDefinition.AllEncryptColumnDefinition != nil {
+		for _, rd := range encryptRuleDefinition.AllEncryptColumnDefinition {
+			allEncryptColumnDefinition = append(allEncryptColumnDefinition, rd.ToString())
+		}
+	}
+
+	if encryptRuleDefinition.QueryWithCipherColumn != nil {
+		queryWithCipherColumn = fmt.Sprintf(",QUERY_WITH_CIPHER_COLUMN=%s", encryptRuleDefinition.QueryWithCipherColumn.ToString())
+	}
+
+	return fmt.Sprintf("%s (%sCOLUMNS(%s%s)%s)",
+		tableName,
+		resourceDefinition,
+		encryptColumnDefinition,
+		strings.Join(allEncryptColumnDefinition, ","),
+		queryWithCipherColumn)
 }
 
 type IfNotExists struct {
@@ -63,13 +119,20 @@ type IfNotExists struct {
 }
 
 func (ifNotExists IfNotExists) ToString() string {
-	return ""
+	return ifNotExists.IfNotExists
 }
 
 type ResourceDefinition struct {
 	ResourceName *CommonIdentifier
 }
 
+func (resourceDefinition *ResourceDefinition) ToString() string {
+	if resourceDefinition.ResourceName != nil {
+		return fmt.Sprintf("RESOURCE=%s", resourceDefinition.ResourceName.ToString())
+	}
+	return ""
+}
+
 type EncryptColumnDefinition struct {
 	ColumnDefinition              *ColumnDefinition
 	PlainColumnDefinition         *PlainColumnDefinition
@@ -82,82 +145,254 @@ type EncryptColumnDefinition struct {
 	QueryWithCipherColumn         *QueryWithCipherColumn
 }
 
+func (encryptColumnDefinition *EncryptColumnDefinition) ToString() (sql string) {
+	var (
+		plainColumnDefinition         string
+		assistedQueryColumnDefinition string
+		likeQueryColumnDefinition     string
+		assistedQueryAlgorithm        string
+		likeQueryAlgorithm            string
+		queryWithCipherColumn         string
+	)
+
+	if encryptColumnDefinition.PlainColumnDefinition != nil {
+		sql += encryptColumnDefinition.PlainColumnDefinition.ToString()
+		plainColumnDefinition = fmt.Sprintf(",%s", encryptColumnDefinition.PlainColumnDefinition.ToString())
+	}
+
+	if encryptColumnDefinition.CipherColumnDefinition != nil {
+		sql += encryptColumnDefinition.CipherColumnDefinition.ToString()
+	}
+
+	if encryptColumnDefinition.AssistedQueryColumnDefinition != nil {
+		assistedQueryColumnDefinition = fmt.Sprintf(",%s", encryptColumnDefinition.AssistedQueryColumnDefinition.ToString())
+	}
+
+	if encryptColumnDefinition.LikeQueryAlgorithm != nil {
+		likeQueryColumnDefinition = fmt.Sprintf(",%s", encryptColumnDefinition.LikeQueryAlgorithm.ToString())
+	}
+
+	if encryptColumnDefinition.AssistedQueryAlgorithm != nil {
+		assistedQueryAlgorithm = fmt.Sprintf(",%s", encryptColumnDefinition.AssistedQueryAlgorithm.ToString())
+	}
+
+	if encryptColumnDefinition.LikeQueryAlgorithm != nil {
+		likeQueryAlgorithm = fmt.Sprintf(",%s", encryptColumnDefinition.LikeQueryAlgorithm.ToString())
+	}
+
+	if encryptColumnDefinition.QueryWithCipherColumn != nil {
+		queryWithCipherColumn = fmt.Sprintf(",QUERY_WITH_CIPHER_COLUMN=%s", encryptColumnDefinition.QueryWithCipherColumn.ToString())
+	}
+
+	return fmt.Sprintf("(%s%s,%s%s%s,%s%s%s%s)",
+		encryptColumnDefinition.ColumnDefinition.ToString(),
+		plainColumnDefinition,
+		encryptColumnDefinition.CipherColumnDefinition.ToString(),
+		assistedQueryColumnDefinition,
+		likeQueryColumnDefinition,
+		encryptColumnDefinition.EncryptAlgorithm.ToString(),
+		assistedQueryAlgorithm,
+		likeQueryAlgorithm,
+		queryWithCipherColumn)
+}
+
 type ColumnDefinition struct {
 	ColumnName *CommonIdentifier
 	DataType   *DataType
 }
 
+func (columnDefinition *ColumnDefinition) ToString() (sql string) {
+	var dataType string
+	if columnDefinition.DataType != nil {
+		dataType = fmt.Sprintf(",DATA_TYP=%s", columnDefinition.DataType.ToString())
+	}
+
+	sql = fmt.Sprintf("NAME=%s%s", columnDefinition.ColumnName.ToString(), dataType)
+
+	return
+}
+
 type PlainColumnDefinition struct {
 	PlainColumnName *CommonIdentifier
 	DataType        *DataType
 }
 
+func (plainColumnDefinition *PlainColumnDefinition) ToString() (sql string) {
+	if plainColumnDefinition.PlainColumnName != nil {
+		sql += fmt.Sprintf("PLAIN=%s", plainColumnDefinition.PlainColumnName.ToString())
+	}
+	if plainColumnDefinition.DataType != nil {
+		sql += plainColumnDefinition.DataType.ToString()
+	}
+	return
+}
+
 type CipherColumnDefinition struct {
 	CipherColumnName *CommonIdentifier
 	DataType         *DataType
 }
 
+func (cipherColumnDefinition *CipherColumnDefinition) ToString() string {
+	var dataType string
+	if cipherColumnDefinition.DataType != nil {
+		dataType = fmt.Sprintf(",CIPHER_DATA_TYPE=%s", dataType)
+	}
+	return fmt.Sprintf("CIPHER=%s%s", cipherColumnDefinition.CipherColumnName.ToString(), dataType)
+}
+
 type AssistedQueryColumnDefinition struct {
 	AssistedQueryColumnName *CommonIdentifier
 	DataType                *DataType
 }
 
+func (assistedQueryColumnDefinition *AssistedQueryColumnDefinition) ToString() string {
+	var dataType string
+	if assistedQueryColumnDefinition.DataType != nil {
+		dataType = fmt.Sprintf(",ASSISTED_QUERY_DATA_TYPE=%s", assistedQueryColumnDefinition.DataType.ToString())
+	}
+	return fmt.Sprintf("ASSISTED_QUERY_COLUMN=%s%s", assistedQueryColumnDefinition.AssistedQueryColumnName.ToString(), dataType)
+}
+
 type LikeQueryColumnDefinition struct {
 	LikeQueryColumnName *CommonIdentifier
 	DataType            *DataType
 }
 
+func (likeQueryColumnDefinition *LikeQueryColumnDefinition) ToString() string {
+	var dataType string
+	if likeQueryColumnDefinition.DataType != nil {
+		dataType = fmt.Sprintf("COMMA_ LIKE_QUERY_DATA_TYPE=%s", likeQueryColumnDefinition.DataType.ToString())
+	}
+	return fmt.Sprintf("LIKE_QUERY_COLUMN=%s%s", likeQueryColumnDefinition.LikeQueryColumnName.ToString(), dataType)
+}
+
 type EncryptAlgorithm struct {
 	AlgorithmDefinition *AlgorithmDefinition
 }
 
+func (encryptAlgorithm *EncryptAlgorithm) ToString() string {
+	return fmt.Sprintf("ENCRYPT_ALGORITHM(%s)", encryptAlgorithm.AlgorithmDefinition.ToString())
+}
+
 type AssistedQueryAlgorithm struct {
 	AlgorithmDefinition *AlgorithmDefinition
 }
 
+func (assistedQueryAlgorithm *AssistedQueryAlgorithm) ToString() string {
+	return assistedQueryAlgorithm.AlgorithmDefinition.ToString()
+}
+
 type AlgorithmDefinition struct {
 	AlgorithmTypeName    *AlgorithmTypeName
 	PropertiesDefinition *PropertiesDefinition
 }
 
+func (algorithmDefinition AlgorithmDefinition) ToString() string {
+	var propertiesDefinition string
+
+	if algorithmDefinition.PropertiesDefinition != nil {
+		propertiesDefinition = fmt.Sprintf(",%s", algorithmDefinition.PropertiesDefinition.ToString())
+	}
+
+	return fmt.Sprintf("TYPE(NAME=%s%s)", algorithmDefinition.AlgorithmTypeName.ToString(), propertiesDefinition)
+}
+
 type PropertiesDefinition struct {
 	Properties *Properties
 }
 
+func (propertiesDefinition *PropertiesDefinition) ToString() string {
+	if propertiesDefinition.Properties != nil {
+		return fmt.Sprintf("PROPERTIES(%s)", propertiesDefinition.Properties.ToString())
+	}
+	return ""
+}
+
 type Properties struct {
 	Properties []*Property
 }
 
+func (properties *Properties) ToString() (sql string) {
+	for _, property := range properties.Properties {
+		sql += property.ToString()
+	}
+	return
+}
+
 type LikeQueryAlgorithm struct {
 	AlgorithmDefinition *AlgorithmDefinition
 }
 
+func (likeQueryAlgorithm *LikeQueryAlgorithm) ToString() (sql string) {
+	if likeQueryAlgorithm.AlgorithmDefinition != nil {
+		sql += likeQueryAlgorithm.ToString()
+	}
+	return
+}
+
 type QueryWithCipherColumn struct {
 	QueryWithCipherColumn string
 }
 
+func (queryWithAlgorithm *QueryWithCipherColumn) ToString() string {
+	return queryWithAlgorithm.QueryWithCipherColumn
+}
+
 type CommonIdentifier struct {
 	Identifier string
 }
 
+func (commonIdentifier *CommonIdentifier) ToString() string {
+	return commonIdentifier.Identifier
+}
+
 type Property struct {
 	Key     string
 	Literal *Literal
 }
 
+func (property *Property) ToString() (sql string) {
+	if property.Literal != nil {
+		sql = fmt.Sprintf("%s=%s", property.Key, property.Literal.ToString())
+	}
+	return
+}
+
 type Literal struct {
 	Literal string
 }
 
+func (literal *Literal) ToString() string {
+	return literal.Literal
+}
+
 type BuildinAlgorithmTypeName struct {
 	AlgorithmTypeName string
 }
 
+func (buildinAlgorithmTypeName *BuildinAlgorithmTypeName) ToString() string {
+	return buildinAlgorithmTypeName.AlgorithmTypeName
+}
+
 type DataType struct {
 	String string
 }
 
+func (dataType *DataType) ToString() string {
+	return dataType.String
+}
+
 type AlgorithmTypeName struct {
 	BuildinAlgorithmTypeName *BuildinAlgorithmTypeName
 	String                   string
 }
+
+func (algorithmTypeName *AlgorithmTypeName) ToString() string {
+	switch {
+	case algorithmTypeName.BuildinAlgorithmTypeName != nil:
+		return algorithmTypeName.BuildinAlgorithmTypeName.ToString()
+	case algorithmTypeName.String != "":
+		return algorithmTypeName.String
+	}
+	return ""
+}
diff --git a/shardingsphere-operator/pkg/distsql/visitor/distsql_test.go b/shardingsphere-operator/pkg/distsql/visitor/distsql_test.go
new file mode 100644
index 0000000..0f5d683
--- /dev/null
+++ b/shardingsphere-operator/pkg/distsql/visitor/distsql_test.go
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package visitor
+
+import (
+	"github.com/antlr/antlr4/runtime/Go/antlr"
+	"github.com/apache/shardingsphere-on-cloud/shardingsphere-operator/pkg/distsql/ast"
+	parser "github.com/apache/shardingsphere-on-cloud/shardingsphere-operator/pkg/distsql/visitor_parser/encrypt"
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+)
+
+var _ = Describe("Distsql", func() {
+	var (
+		encryptDistSQL = "CREATE ENCRYPT RULE t_encrypt (COLUMNS((NAME=user_id,PLAIN=user_plain,CIPHER=user_cipher,ENCRYPT_ALGORITHM(TYPE(NAME='AES',PROPERTIES('aes-key-value'='123456abc')))),(NAME=order_id,CIPHER=order_cipher,ENCRYPT_ALGORITHM(TYPE(NAME='MD5')))),QUERY_WITH_CIPHER_COLUMN=true);"
+		visitor        = Visitor{}
+		ast            = &ast.CreateEncryptRule{}
+	)
+
+	BeforeEach(func() {
+		inputStream := antlr.NewInputStream(encryptDistSQL)
+		lexer := parser.NewRDLStatementLexer(inputStream)
+		tokens := antlr.NewCommonTokenStream(lexer, antlr.TokenDefaultChannel)
+		distSQLParser := parser.NewRDLStatementParser(tokens)
+		createEncryptRule := distSQLParser.CreateEncryptRule()
+		ast = visitor.VisitCreateEncryptRule(createEncryptRule.(*parser.CreateEncryptRuleContext))
+	})
+
+	Context("parse distSQL to AST", func() {
+		It("should encrypt distSQL parse correctly", func() {
+			Expect(ast.AllEncryptRuleDefinition[0].TableName.Identifier).To(Equal("t_encrypt"))
+		})
+	})
+
+	Context("covert distSQL AST to string", func() {
+		It("should encrypt distsql parse correctly", func() {
+			Expect(ast.AllEncryptRuleDefinition[0].TableName.ToString()).To(Equal("t_encrypt"))
+		})
+	})
+})
diff --git a/shardingsphere-operator/pkg/distsql/visitor/rdl_visitor.go b/shardingsphere-operator/pkg/distsql/visitor/rdl_visitor.go
index 3adaeaa..36a4662 100644
--- a/shardingsphere-operator/pkg/distsql/visitor/rdl_visitor.go
+++ b/shardingsphere-operator/pkg/distsql/visitor/rdl_visitor.go
@@ -30,17 +30,20 @@ type Visitor struct {
 
 func (v *Visitor) VisitCreateEncryptRule(ctx *parser.CreateEncryptRuleContext) *ast.CreateEncryptRule {
 	stmt := &ast.CreateEncryptRule{}
-	stmt.Create = ctx.CREATE().GetText()
-	stmt.Encrypt = ctx.ENCRYPT().GetText()
-	stmt.EncryptName = ctx.RULE().GetText()
 
 	if ctx.IfNotExists() != nil {
 		stmt.IfNotExists = v.VisitIfNotExists(ctx.IfNotExists().(*parser.IfNotExistsContext))
 	}
 
+	if ctx.EncryptRuleDefinition(0) != nil {
+		stmt.EncryptRuleDefinition = v.VisitEncryptRuleDefinition(ctx.EncryptRuleDefinition(0).(*parser.EncryptRuleDefinitionContext))
+	}
+
 	if ctx.AllEncryptRuleDefinition() != nil {
 		for _, r := range ctx.AllEncryptRuleDefinition() {
-			stmt.AllEncryptRuleDefinition = append(stmt.AllEncryptRuleDefinition, v.VisitEncryptRuleDefinition(r.(*parser.EncryptRuleDefinitionContext)))
+			if r != nil {
+				stmt.AllEncryptRuleDefinition = append(stmt.AllEncryptRuleDefinition, v.VisitEncryptRuleDefinition(r.(*parser.EncryptRuleDefinitionContext)))
+			}
 		}
 	}
 
@@ -95,6 +98,10 @@ func (v *Visitor) VisitEncryptRuleDefinition(ctx *parser.EncryptRuleDefinitionCo
 		stmt.ResourceDefinition = v.VisitResourceDefinition(ctx.ResourceDefinition().(*parser.ResourceDefinitionContext))
 	}
 
+	// if ctx.EncryptColumnDefinition(0) != nil {
+	// 	stmt.EncryptColumnDefinition = v.VisitEncryptColumnDefinition(ctx.EncryptColumnDefinition(0).(*parser.EncryptColumnDefinitionContext))
+	// }
+
 	if ctx.AllEncryptColumnDefinition() != nil {
 		for _, column := range ctx.AllEncryptColumnDefinition() {
 			stmt.AllEncryptColumnDefinition = append(stmt.AllEncryptColumnDefinition, v.VisitEncryptColumnDefinition(column.(*parser.EncryptColumnDefinitionContext)))
diff --git a/shardingsphere-operator/pkg/distsql/visitor/visitor_suite_test.go b/shardingsphere-operator/pkg/distsql/visitor/visitor_suite_test.go
new file mode 100644
index 0000000..8cd3276
--- /dev/null
+++ b/shardingsphere-operator/pkg/distsql/visitor/visitor_suite_test.go
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package visitor
+
+import (
+	"testing"
+
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+)
+
+func TestVisitor(t *testing.T) {
+	RegisterFailHandler(Fail)
+	RunSpecs(t, "Visitor Suite")
+}