diff --git a/ddlmod.go b/ddlmod.go
index 9c93e6a..c839cd7 100644
--- a/ddlmod.go
+++ b/ddlmod.go
@@ -17,23 +17,11 @@ var (
 	indexRegexp        = regexp.MustCompile(fmt.Sprintf(`(?is)CREATE(?: UNIQUE)? INDEX [%v]?[\w\d-]+[%v]?(?s:.*?)ON (.*)$`, sqliteSeparator, sqliteSeparator))
 	tableRegexp        = regexp.MustCompile(fmt.Sprintf(`(?is)(CREATE TABLE [%v]?[\w\d-]+[%v]?)(?:\s*\((.*)\))?`, sqliteSeparator, sqliteSeparator))
 	separatorRegexp    = regexp.MustCompile(fmt.Sprintf("[%v]", sqliteSeparator))
-	columnsRegexp      = regexp.MustCompile(fmt.Sprintf(`[(,][%v]?(\w+)[%v]?`, sqliteSeparator, sqliteSeparator))
 	columnRegexp       = regexp.MustCompile(fmt.Sprintf(`^[%v]?([\w\d]+)[%v]?\s+([\w\(\)\d]+)(.*)$`, sqliteSeparator, sqliteSeparator))
 	defaultValueRegexp = regexp.MustCompile(`(?i) DEFAULT \(?(.+)?\)?( |COLLATE|GENERATED|$)`)
 	regRealDataType    = regexp.MustCompile(`[^\d](\d+)[^\d]?`)
 )
 
-func getAllColumns(s string) []string {
-	allMatches := columnsRegexp.FindAllStringSubmatch(s, -1)
-	columns := make([]string, 0, len(allMatches))
-	for _, matches := range allMatches {
-		if len(matches) > 1 {
-			columns = append(columns, matches[1])
-		}
-	}
-	return columns
-}
-
 type ddl struct {
 	head    string
 	fields  []string
@@ -110,9 +98,10 @@ func parseDDL(strs ...string) (*ddl, error) {
 				if strings.HasPrefix(fUpper, "CONSTRAINT") {
 					matches := uniqueRegexp.FindStringSubmatch(f)
 					if len(matches) > 0 {
-						if columns := getAllColumns(matches[1]); len(columns) == 1 {
+						cols, err := parseAllColumns(matches[1])
+						if err == nil && len(cols) == 1 {
 							for idx, column := range result.columns {
-								if column.NameValue.String == columns[0] {
+								if column.NameValue.String == cols[0] {
 									column.UniqueValue = sql.NullBool{Bool: true, Valid: true}
 									result.columns[idx] = column
 									break
@@ -123,12 +112,15 @@ func parseDDL(strs ...string) (*ddl, error) {
 					continue
 				}
 				if strings.HasPrefix(fUpper, "PRIMARY KEY") {
-					for _, name := range getAllColumns(f) {
-						for idx, column := range result.columns {
-							if column.NameValue.String == name {
-								column.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
-								result.columns[idx] = column
-								break
+					cols, err := parseAllColumns(f)
+					if err == nil {
+						for _, name := range cols {
+							for idx, column := range result.columns {
+								if column.NameValue.String == name {
+									column.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
+									result.columns[idx] = column
+									break
+								}
 							}
 						}
 					}
diff --git a/ddlmod_parse_all_columns.go b/ddlmod_parse_all_columns.go
new file mode 100644
index 0000000..760acf8
--- /dev/null
+++ b/ddlmod_parse_all_columns.go
@@ -0,0 +1,133 @@
+package sqlite
+
+import (
+	"errors"
+	"fmt"
+)
+
+// parseAllColumnsState is the state of the parseAllColumns scanner.
+type parseAllColumnsState int
+
+const (
+	// parseAllColumnsStateNone: the opening '(' has not been seen yet.
+	parseAllColumnsStateNone parseAllColumnsState = iota
+	// parseAllColumnsStateBeginning: a column name is expected next.
+	parseAllColumnsStateBeginning
+	// parseAllColumnsStateReadingRawName: inside an unquoted column name.
+	parseAllColumnsStateReadingRawName
+	// parseAllColumnsStateReadingQuotedName: inside a quoted column name.
+	parseAllColumnsStateReadingQuotedName
+	// parseAllColumnsStateEndOfName: a name has ended; ',' or ')' closes it.
+	parseAllColumnsStateEndOfName
+	// parseAllColumnsStateSkippingModifiers: skipping text such as COLLATE,
+	// ASC or DESC that follows a name, up to the next ',' or ')'.
+	parseAllColumnsStateSkippingModifiers
+)
+
+// parseAllColumns returns the column names found in the first parenthesised
+// list of in, for example "PRIMARY KEY (a, `b`)" yields ["a", "b"].
+//
+// Names may be unquoted or wrapped in backticks, double quotes, single quotes
+// or square brackets; inside a quoted name a doubled closing quote stands for
+// one literal quote character. Text between a name and the next ',' or ')'
+// (COLLATE, ASC, DESC) is skipped. Everything after the closing ')' is
+// ignored. On error the returned slice is always nil.
+func parseAllColumns(in string) ([]string, error) {
+	var (
+		columns []string
+		name    []rune
+		quote   rune
+	)
+	state := parseAllColumnsStateNone
+	s := []rune(in)
+	for i := 0; i < len(s); i++ {
+		r := s[i]
+		switch state {
+		case parseAllColumnsStateNone:
+			if r == '(' {
+				state = parseAllColumnsStateBeginning
+			}
+		case parseAllColumnsStateBeginning:
+			switch {
+			case isSpace(r):
+				// Leading whitespace before a name carries no meaning.
+			case isQuote(r):
+				state = parseAllColumnsStateReadingQuotedName
+				quote = r
+			case r == '[':
+				state = parseAllColumnsStateReadingQuotedName
+				quote = ']'
+			case r == ')' || isSeparator(r):
+				// An empty list or an empty list item has no name.
+				return nil, fmt.Errorf("unexpected token: %c", r)
+			default:
+				state = parseAllColumnsStateReadingRawName
+				name = append(name, r)
+			}
+		case parseAllColumnsStateReadingRawName:
+			switch {
+			case isSeparator(r):
+				columns = append(columns, string(name))
+				name = name[:0]
+				state = parseAllColumnsStateBeginning
+			case r == ')':
+				return append(columns, string(name)), nil
+			case isQuote(r):
+				return nil, fmt.Errorf("unexpected token: %c", r)
+			case isSpace(r):
+				columns = append(columns, string(name))
+				name = name[:0]
+				state = parseAllColumnsStateEndOfName
+			default:
+				name = append(name, r)
+			}
+		case parseAllColumnsStateReadingQuotedName:
+			switch {
+			case r != quote:
+				name = append(name, r)
+			case i+1 < len(s) && s[i+1] == quote:
+				// A doubled quote is an escaped quote character.
+				name = append(name, quote)
+				i++
+			default:
+				columns = append(columns, string(name))
+				name = name[:0]
+				state = parseAllColumnsStateEndOfName
+			}
+		case parseAllColumnsStateEndOfName:
+			switch {
+			case isSpace(r):
+				// Trailing whitespace after a name carries no meaning.
+			case isSeparator(r):
+				state = parseAllColumnsStateBeginning
+			case r == ')':
+				return columns, nil
+			default:
+				state = parseAllColumnsStateSkippingModifiers
+			}
+		case parseAllColumnsStateSkippingModifiers:
+			switch {
+			case isSeparator(r):
+				state = parseAllColumnsStateBeginning
+			case r == ')':
+				return columns, nil
+			}
+		}
+	}
+	return nil, errors.New("unexpected end")
+}
+
+// isSpace reports whether r is whitespace that may surround a column name.
+func isSpace(r rune) bool {
+	return r == ' ' || r == '\t' || r == '\n' || r == '\r'
+}
+
+// isQuote reports whether r opens (and closes) a quoted column name.
+func isQuote(r rune) bool {
+	return r == '`' || r == '"' || r == '\''
+}
+
+// isSeparator reports whether r separates two entries of a column list.
+func isSeparator(r rune) bool {
+	return r == ','
+}
diff --git a/ddlmod_parse_all_columns_test.go b/ddlmod_parse_all_columns_test.go
new file mode 100644
index 0000000..eb70cdd
--- /dev/null
+++ b/ddlmod_parse_all_columns_test.go
@@ -0,0 +1,79 @@
+package sqlite
+
+import "testing"
+
+func TestParseAllColumns(t *testing.T) {
+	tc := []struct {
+		name     string
+		input    string
+		expected []string
+	}{
+		{
+			name:     "Simple case",
+			input:    "PRIMARY KEY (column1, column2)",
+			expected: []string{"column1", "column2"},
+		},
+		{
+			name:     "Quoted column name",
+			input:    "PRIMARY KEY (`column,xxx`, \"column 2\", \"column)3\", 'column''4', \"column\"\"5\")",
+			expected: []string{"column,xxx", "column 2", "column)3", "column'4", "column\"5"},
+		},
+		{
+			name:     "Japanese column name",
+			input:    "PRIMARY KEY (カラム1, `カラム2`)",
+			expected: []string{"カラム1", "カラム2"},
+		},
+		{
+			name:     "Column name quoted with []",
+			input:    "PRIMARY KEY ([column1], [column2])",
+			expected: []string{"column1", "column2"},
+		},
+		{
+			name:     "Newlines between columns",
+			input:    "PRIMARY KEY (\n\tcolumn1,\n\tcolumn2\n)",
+			expected: []string{"column1", "column2"},
+		},
+		{
+			name:     "Sort order and collation modifiers",
+			input:    "PRIMARY KEY (column1 COLLATE NOCASE DESC, `column2` ASC)",
+			expected: []string{"column1", "column2"},
+		},
+	}
+	for _, tt := range tc {
+		t.Run(tt.name, func(t *testing.T) {
+			cols, err := parseAllColumns(tt.input)
+			if err != nil {
+				t.Fatalf("Failed to parse columns: %s", err)
+			}
+			if len(cols) != len(tt.expected) {
+				t.Fatalf("Expected %d columns, got %d", len(tt.expected), len(cols))
+			}
+			for i, col := range cols {
+				if col != tt.expected[i] {
+					t.Errorf("Expected %s, got %s", tt.expected[i], col)
+				}
+			}
+		})
+	}
+}
+
+func TestParseAllColumnsErrors(t *testing.T) {
+	inputs := []string{
+		"PRIMARY KEY",
+		"PRIMARY KEY ()",
+		"PRIMARY KEY (column1,,column2)",
+		"PRIMARY KEY (column1",
+		"PRIMARY KEY (col\"umn1)",
+	}
+	for _, input := range inputs {
+		t.Run(input, func(t *testing.T) {
+			cols, err := parseAllColumns(input)
+			if err == nil {
+				t.Fatalf("Expected an error, got columns %q", cols)
+			}
+			if cols != nil {
+				t.Errorf("Expected nil columns on error, got %q", cols)
+			}
+		})
+	}
+}