Skip to content

Commit

Permalink
add more wrap
Browse files Browse the repository at this point in the history
  • Loading branch information
yinxulai committed Mar 23, 2024
1 parent 229e6fe commit 8b0cc50
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 8 deletions.
53 changes: 53 additions & 0 deletions alter_table.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package sqls

import (
"fmt"
"strings"
)

type alterTableBuilder struct {
builder *sqlBuilder
table string
addConstraint map[string]string
}

func newAlterTableBuilder() *alterTableBuilder {
builder := &alterTableBuilder{}
builder.builder = newSqlBuilder()
builder.addConstraint =make(map[string]string)
return builder
}

func ALTER_TABLE(table string) *alterTableBuilder {
s := newAlterTableBuilder()
s.table = table
return s
}

func (s *alterTableBuilder) ADD_CONSTRAINT(name string, options string) *alterTableBuilder {
s.addConstraint[name] = options
return s
}

func (s *alterTableBuilder) Param(v any) string {
return s.builder.Param(v)
}

func (s *alterTableBuilder) Params() []any {
return s.builder.Params(s.String())
}

func (s *alterTableBuilder) String() string {
var sqlString string
sqlString += s.builder.join("ALTER TABLE", "", []string{s.table}, "", "")

if len(s.addConstraint) > 0 {
var addConstraint []string
for name, options := range s.addConstraint {
addConstraint = append(addConstraint, fmt.Sprintf("ADD CONSTRAINT %s %s", name, options))
}
sqlString += s.builder.join("", "", addConstraint, ", ", "")
}

return strings.Trim(sqlString, " ")
}
18 changes: 18 additions & 0 deletions alter_table_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package sqls

import (
"testing"
)

func TestAlterTableAddConstraint(t *testing.T) {
s := ALTER_TABLE("TABLE")
s.ADD_CONSTRAINT("test", "UNIQUE (column1, column2)")
s.ADD_CONSTRAINT("test2", "PRIMARY KEY (column1, column2)")

result := s.String()
expected := "ALTER TABLE TABLE ADD CONSTRAINT test UNIQUE (column1, column2), ADD CONSTRAINT test2 PRIMARY KEY (column1, column2)"

if result != expected {
t.Errorf("Case1() 返回值为 %s,期望值为 %s", result, expected)
}
}
11 changes: 11 additions & 0 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"regexp"
"slices"
"sort"
)

type sqlBuilder struct {
Expand Down Expand Up @@ -53,13 +54,23 @@ func (s *sqlBuilder) Param(v any) string {
return key
}

// 目前仅支持 postgresql
func (s *sqlBuilder) Params(sql string) []any {
result := []any{}
matches := paramRegexp.FindAllString(sql, -1)
sort.Strings(matches)

for _, match := range matches {
result = append(result, s.params[match])
}

return result
}

func (s *sqlBuilder) Clone() *sqlBuilder {
ns := &sqlBuilder{}
for key, value := range ns.params {
ns.params[key] = value
}
return ns
}
26 changes: 23 additions & 3 deletions common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,31 @@ import (

func TestCommonParamRegexp(t *testing.T) {
caseText := "$123"
result := paramRegexp.FindAllStringSubmatch(caseText, -1)
result := paramRegexp.FindAllString(caseText, -1)

expected := []string{"$123"}

if slices.Compare[[]string](result[0], expected) != 0 {
t.Errorf("FindAllString(%s) 返回值为 %v,期望值为 %v", caseText, result[0], expected)
if slices.Compare[[]string](result, expected) != 0 {
t.Errorf("FindAllString(%s) 返回值为 %v,期望值为 %v", caseText, result, expected)
}
}

func TestSimplePostgresqlParam(t *testing.T) {
builder := newSqlBuilder()

caseText := builder.Param(1) + " " + builder.Param(2) + " " + builder.Param(3)
result := paramRegexp.FindAllString(caseText, -1)

expected := []string{"$1", "$2", "$3"}

if slices.Compare[[]string](result, expected) != 0 {
t.Errorf("FindAllString(%s) 返回值为 %v,期望值为 %v", caseText, result, expected)
}

params := builder.Params(caseText)
paramSlice := []int{params[0].(int), params[1].(int), params[2].(int)}
expectedParams := []int{1, 2, 3}
if slices.Compare[[]int](paramSlice, expectedParams) != 0 {
t.Errorf("Params(%s) 返回值为 %v,期望值为 %v", caseText, result, expected)
}
}
8 changes: 7 additions & 1 deletion create_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
type createTableStatement struct {
table string
columns []string
options []string
ifNotExists bool
}

Expand Down Expand Up @@ -38,6 +39,11 @@ func (s *createTableBuilder) COLUMN(v string) *createTableBuilder {
return s
}

func (s *createTableBuilder) OPTIONS(v string) *createTableBuilder {
s.statement.columns = append(s.statement.columns, v)
return s
}

func (s *createTableBuilder) Param(v any) string {
return s.builder.Param(v)
}
Expand All @@ -50,7 +56,7 @@ func (s *createTableBuilder) String() string {
}

sqlString += s.builder.join(keyword, "", []string{s.statement.table}, "", "")
sqlString += s.builder.join("", "(", s.statement.columns, ", ", ")")
sqlString += s.builder.join("", "(", append(append([]string{}, s.statement.columns...), s.statement.options...), ", ", ")")
return strings.Trim(sqlString, " ")
}

Expand Down
24 changes: 21 additions & 3 deletions insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ import (
)

type insertStatement struct {
table []string
columns []string
values []string
table []string
columns []string
values []string
onConflict []string
doUpdateSet []string
}

type insertBuilder struct {
Expand Down Expand Up @@ -37,6 +39,16 @@ func (s *insertBuilder) VALUES(key string, value string) *insertBuilder {
return s
}

func (s *insertBuilder) ON_CONFLICT(v string) *insertBuilder {
s.statement.onConflict = append(s.statement.onConflict, v)
return s
}

func (s *insertBuilder) DO_UPDATE_SET(key string, value string) *insertBuilder {
s.statement.doUpdateSet = append(s.statement.onConflict, key+"="+value)
return s
}

func (s *insertBuilder) Param(v any) string {
return s.builder.Param(v)
}
Expand All @@ -46,6 +58,12 @@ func (s *insertBuilder) String() string {
sqlString += s.builder.join("INSERT INTO", "", s.statement.table, "", "")
sqlString += s.builder.join("", "(", s.statement.columns, ", ", ")")
sqlString += s.builder.join("VALUES", "(", s.statement.values, ", ", ")")

if len(s.statement.onConflict) > 0 && len(s.statement.doUpdateSet) > 0 {
sqlString += s.builder.join("ON CONFLICT", "(", s.statement.onConflict, ", ", ")")
sqlString += s.builder.join("DO UPDATE SET", "", s.statement.onConflict, ", ", "")
}

return strings.Trim(sqlString, " ")
}

Expand Down
6 changes: 5 additions & 1 deletion insert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@ func TestSimpleInsertStatement(t *testing.T) {
s.VALUES("a", s.Param(1))
s.VALUES("b", s.Param(1))
s.VALUES("c", s.Param(1))
s.ON_CONFLICT("a, b")
s.ON_CONFLICT("c")
s.DO_UPDATE_SET("a","1")
s.DO_UPDATE_SET("b","2")

result := s.String()
expected := "INSERT INTO PERSON (a, b, c) VALUES ($1, $2, $3)"
expected := "INSERT INTO PERSON (a, b, c) VALUES ($1, $2, $3) ON CONFLICT (a, b, c) DO UPDATE SET a, b, c"

if result != expected {
t.Errorf("Case1() 返回值为 %s,期望值为 %s", result, expected)
Expand Down

0 comments on commit 8b0cc50

Please sign in to comment.