Skip to content

Commit

Permalink
🐛 fix logic of parsing multiple columns (i.e. for PRIMARY KEYS, CONST…
Browse files Browse the repository at this point in the history
…RAINT) (#193)

* ✨ impl parseAllColumns

* 🐛 replacing getAllColumns with parseAllColumns

(cherry picked from commit 0c7e33b)

* 🐛 (parseAllColumns) fix for []-quoted cases
  • Loading branch information
Shion1305 authored Sep 30, 2024
1 parent b29e7fc commit 02b8e06
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 20 deletions.
32 changes: 12 additions & 20 deletions ddlmod.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,11 @@ var (
indexRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)CREATE(?: UNIQUE)? INDEX [%v]?[\w\d-]+[%v]?(?s:.*?)ON (.*)$`, sqliteSeparator, sqliteSeparator))
tableRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)(CREATE TABLE [%v]?[\w\d-]+[%v]?)(?:\s*\((.*)\))?`, sqliteSeparator, sqliteSeparator))
separatorRegexp = regexp.MustCompile(fmt.Sprintf("[%v]", sqliteSeparator))
columnsRegexp = regexp.MustCompile(fmt.Sprintf(`[(,][%v]?(\w+)[%v]?`, sqliteSeparator, sqliteSeparator))
columnRegexp = regexp.MustCompile(fmt.Sprintf(`^[%v]?([\w\d]+)[%v]?\s+([\w\(\)\d]+)(.*)$`, sqliteSeparator, sqliteSeparator))
defaultValueRegexp = regexp.MustCompile(`(?i) DEFAULT \(?(.+)?\)?( |COLLATE|GENERATED|$)`)
regRealDataType = regexp.MustCompile(`[^\d](\d+)[^\d]?`)
)

// getAllColumns collects the first capture group of every columnsRegexp match
// found in s, in match order. The returned slice is never nil.
func getAllColumns(s string) []string {
	found := columnsRegexp.FindAllStringSubmatch(s, -1)
	names := make([]string, 0, len(found))
	for i := range found {
		// Index 0 is the whole match; index 1 is the captured name.
		if groups := found[i]; len(groups) > 1 {
			names = append(names, groups[1])
		}
	}
	return names
}

type ddl struct {
head string
fields []string
Expand Down Expand Up @@ -110,9 +98,10 @@ func parseDDL(strs ...string) (*ddl, error) {
if strings.HasPrefix(fUpper, "CONSTRAINT") {
matches := uniqueRegexp.FindStringSubmatch(f)
if len(matches) > 0 {
if columns := getAllColumns(matches[1]); len(columns) == 1 {
cols, err := parseAllColumns(matches[1])
if err == nil && len(cols) == 1 {
for idx, column := range result.columns {
if column.NameValue.String == columns[0] {
if column.NameValue.String == cols[0] {
column.UniqueValue = sql.NullBool{Bool: true, Valid: true}
result.columns[idx] = column
break
Expand All @@ -123,12 +112,15 @@ func parseDDL(strs ...string) (*ddl, error) {
continue
}
if strings.HasPrefix(fUpper, "PRIMARY KEY") {
for _, name := range getAllColumns(f) {
for idx, column := range result.columns {
if column.NameValue.String == name {
column.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
result.columns[idx] = column
break
cols, err := parseAllColumns(f)
if err == nil {
for _, name := range cols {
for idx, column := range result.columns {
if column.NameValue.String == name {
column.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
result.columns[idx] = column
break
}
}
}
}
Expand Down
117 changes: 117 additions & 0 deletions ddlmod_parse_all_columns.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package sqlite

import (
"errors"
"fmt"
)

// parseAllColumnsState identifies the step the scanner in parseAllColumns is
// currently in.
type parseAllColumnsState int

const (
// parseAllColumnsState_NONE is the initial state: the opening '(' of the
// column list has not been seen yet.
parseAllColumnsState_NONE parseAllColumnsState = iota
// parseAllColumnsState_Beginning means the scanner is positioned before a
// column name, i.e. right after '(' or ','.
parseAllColumnsState_Beginning
// parseAllColumnsState_ReadingRawName means the scanner is inside an
// unquoted column name.
parseAllColumnsState_ReadingRawName
// parseAllColumnsState_ReadingQuotedName means the scanner is inside a
// quoted column name and is waiting for the matching closing quote.
parseAllColumnsState_ReadingQuotedName
// parseAllColumnsState_EndOfName means a name has been completed and the
// scanner expects ',' or ')' next.
parseAllColumnsState_EndOfName
// parseAllColumnsState_State_End means the closing ')' of the column list
// has been reached.
parseAllColumnsState_State_End
)

// parseAllColumns extracts the column names listed inside the first
// parenthesised group of in; for example "PRIMARY KEY (a, `b`)" yields
// []string{"a", "b"}.
//
// Names may be bare or quoted with `, ", ' or [...]. Inside a quoted name the
// closing quote character is escaped by doubling it. Everything before the
// opening '(' and everything after the ')' that closes the list is ignored.
//
// On malformed input a nil slice and a non-nil error are returned. That
// covers an empty entry such as "()", "(a,)" or "(a,,b)", a quote character
// in the middle of a bare name, any token other than ',' or ')' after a
// completed name (so a trailing ASC/DESC/COLLATE is rejected), and input that
// ends before the closing ')'.
func parseAllColumns(in string) ([]string, error) {
	s := []rune(in)
	var columns []string
	var name []rune
	state := parseAllColumnsState_NONE
	quote := rune(0)

	// flush stores the name collected so far and empties the buffer. The
	// string conversion copies the runes, so the buffer can be reused.
	flush := func() {
		columns = append(columns, string(name))
		name = name[:0]
	}

scan:
	for i := 0; i < len(s); i++ {
		c := s[i]
		switch state {
		case parseAllColumnsState_NONE:
			// Skip the keyword prefix up to the opening parenthesis.
			if c == '(' {
				state = parseAllColumnsState_Beginning
			}
		case parseAllColumnsState_Beginning:
			switch {
			case isSpace(c):
				// Whitespace in front of a name carries no meaning.
			case isQuote(c):
				state = parseAllColumnsState_ReadingQuotedName
				quote = c
			case c == '[':
				// A bracket-quoted name is closed by ']', not by '['.
				state = parseAllColumnsState_ReadingQuotedName
				quote = ']'
			case c == ')' || isSeparator(c):
				// A name is required here; ',' or ')' means an empty entry.
				return nil, fmt.Errorf("unexpected token: %s", string(c))
			default:
				state = parseAllColumnsState_ReadingRawName
				name = append(name, c)
			}
		case parseAllColumnsState_ReadingRawName:
			switch {
			case isSeparator(c):
				flush()
				state = parseAllColumnsState_Beginning
			case c == ')':
				flush()
				state = parseAllColumnsState_State_End
				break scan
			case isQuote(c):
				return nil, fmt.Errorf("unexpected token: %s", string(c))
			case isSpace(c):
				flush()
				state = parseAllColumnsState_EndOfName
			default:
				name = append(name, c)
			}
		case parseAllColumnsState_ReadingQuotedName:
			if c != quote {
				name = append(name, c)
				continue
			}
			// A doubled quote character is an escaped literal quote.
			if i+1 < len(s) && s[i+1] == quote {
				name = append(name, quote)
				i++
				continue
			}
			flush()
			state = parseAllColumnsState_EndOfName
		case parseAllColumnsState_EndOfName:
			switch {
			case isSpace(c):
				// Whitespace between a name and the next ',' or ')'.
			case isSeparator(c):
				state = parseAllColumnsState_Beginning
			case c == ')':
				state = parseAllColumnsState_State_End
				break scan
			default:
				return nil, fmt.Errorf("unexpected token: %s", string(c))
			}
		}
	}
	if state != parseAllColumnsState_State_End {
		return nil, errors.New("unexpected end")
	}
	return columns, nil
}

// isSpace reports whether r is whitespace that may surround a column name:
// space, tab, line feed, carriage return or form feed. Line breaks are
// included so that a column list written across several lines does not get
// the break characters absorbed into a bare name.
func isSpace(r rune) bool {
	switch r {
	case ' ', '\t', '\n', '\r', '\f':
		return true
	}
	return false
}

// isQuote reports whether r is one of the symmetric quote characters that
// can open and close a column name: backtick, double quote or single quote.
func isQuote(r rune) bool {
	switch r {
	case '`', '"', '\'':
		return true
	default:
		return false
	}
}

// isSeparator reports whether r is the comma that separates two entries of a
// column list.
func isSeparator(r rune) bool {
	switch r {
	case ',':
		return true
	default:
		return false
	}
}
48 changes: 48 additions & 0 deletions ddlmod_parse_all_columns_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package sqlite

import "testing"

// TestParseAllColumns checks that parseAllColumns returns the expected names
// for bare, quoted, non-ASCII and bracket-quoted column lists.
func TestParseAllColumns(t *testing.T) {
	tc := []struct {
		name     string
		input    string
		expected []string
	}{
		{
			name:     "Simple case",
			input:    "PRIMARY KEY (column1, column2)",
			expected: []string{"column1", "column2"},
		},
		{
			name:     "Quoted column name",
			input:    "PRIMARY KEY (`column,xxx`, \"column 2\", \"column)3\", 'column''4', \"column\"\"5\")",
			expected: []string{"column,xxx", "column 2", "column)3", "column'4", "column\"5"},
		},
		{
			name:     "Japanese column name",
			input:    "PRIMARY KEY (カラム1, `カラム2`)",
			expected: []string{"カラム1", "カラム2"},
		},
		{
			name:     "Column name quoted with []",
			input:    "PRIMARY KEY ([column1], [column2])",
			expected: []string{"column1", "column2"},
		},
	}
	for _, tt := range tc {
		t.Run(tt.name, func(t *testing.T) {
			cols, err := parseAllColumns(tt.input)
			// Stop the subtest on a parse error or a length mismatch: the
			// element-wise comparison below indexes tt.expected by the
			// position in cols and would panic if cols were longer.
			if err != nil {
				t.Fatalf("Failed to parse columns: %s", err)
			}
			if len(cols) != len(tt.expected) {
				t.Fatalf("Expected %d columns, got %d", len(tt.expected), len(cols))
			}
			for i, col := range cols {
				if col != tt.expected[i] {
					t.Errorf("Expected %s, got %s", tt.expected[i], col)
				}
			}
		})
	}
}

0 comments on commit 02b8e06

Please sign in to comment.