diff --git a/plugin/maxmind/country_csv.go b/plugin/maxmind/country_csv.go index e8f4781fb..cf7219997 100644 --- a/plugin/maxmind/country_csv.go +++ b/plugin/maxmind/country_csv.go @@ -59,6 +59,14 @@ func newGeoLite2CountryCSV(action lib.Action, data json.RawMessage) (lib.InputCo tmp.IPv6File = defaultIPv6File } + // Filter want list + wantList := make(map[string]bool) + for _, want := range tmp.Want { + if want = strings.ToUpper(strings.TrimSpace(want)); want != "" { + wantList[want] = true + } + } + return &geoLite2CountryCSV{ Type: typeCountryCSV, Action: action, @@ -66,7 +74,7 @@ func newGeoLite2CountryCSV(action lib.Action, data json.RawMessage) (lib.InputCo CountryCodeFile: tmp.CountryCodeFile, IPv4File: tmp.IPv4File, IPv6File: tmp.IPv6File, - Want: tmp.Want, + Want: wantList, OnlyIPType: tmp.OnlyIPType, }, nil } @@ -78,7 +86,7 @@ type geoLite2CountryCSV struct { CountryCodeFile string IPv4File string IPv6File string - Want []string + Want map[string]bool OnlyIPType lib.IPType } @@ -171,11 +179,16 @@ func (g *geoLite2CountryCSV) getCountryCode() (map[string]string, error) { } id := strings.TrimSpace(line[0]) - countryCode := strings.TrimSpace(line[4]) + countryCode := strings.ToUpper(strings.TrimSpace(line[4])) if id == "" || countryCode == "" { continue } - ccMap[id] = strings.ToUpper(countryCode) + + if len(g.Want) > 0 && !g.Want[countryCode] { + continue + } + + ccMap[id] = countryCode } if len(ccMap) == 0 { @@ -206,14 +219,6 @@ func (g *geoLite2CountryCSV) process(file string, ccMap map[string]string, entri } defer f.Close() - // Filter want list - wantList := make(map[string]bool) - for _, want := range g.Want { - if want = strings.ToUpper(strings.TrimSpace(want)); want != "" { - wantList[want] = true - } - } - reader := csv.NewReader(f) reader.Read() // skip header @@ -243,9 +248,6 @@ func (g *geoLite2CountryCSV) process(file string, ccMap map[string]string, entri } if countryCode, found := ccMap[ccID]; found { - if len(wantList) > 0 && !wantList[countryCode] { - continue - } cidrStr := strings.ToLower(strings.TrimSpace(record[0])) entry, found := entries[countryCode] if !found { diff --git a/plugin/maxmind/mmdb_in.go b/plugin/maxmind/mmdb_in.go index 3c25931de..8789e0724 100644 --- a/plugin/maxmind/mmdb_in.go +++ b/plugin/maxmind/mmdb_in.go @@ -46,12 +46,20 @@ func newMaxmindMMDBIn(action lib.Action, data json.RawMessage) (lib.InputConvert tmp.URI = defaultMMDBFile } + // Filter want list + wantList := make(map[string]bool) + for _, want := range tmp.Want { + if want = strings.ToUpper(strings.TrimSpace(want)); want != "" { + wantList[want] = true + } + } + return &maxmindMMDBIn{ Type: typeMaxmindMMDBIn, Action: action, Description: descMaxmindMMDBIn, URI: tmp.URI, - Want: tmp.Want, + Want: wantList, OnlyIPType: tmp.OnlyIPType, }, nil } @@ -61,68 +69,55 @@ type maxmindMMDBIn struct { Action lib.Action Description string URI string - Want []string + Want map[string]bool OnlyIPType lib.IPType } -func (g *maxmindMMDBIn) GetType() string { - return g.Type +func (m *maxmindMMDBIn) GetType() string { + return m.Type } -func (g *maxmindMMDBIn) GetAction() lib.Action { - return g.Action +func (m *maxmindMMDBIn) GetAction() lib.Action { + return m.Action } -func (g *maxmindMMDBIn) GetDescription() string { - return g.Description +func (m *maxmindMMDBIn) GetDescription() string { + return m.Description } -func (g *maxmindMMDBIn) Input(container lib.Container) (lib.Container, error) { +func (m *maxmindMMDBIn) Input(container lib.Container) (lib.Container, error) { var content []byte var err error switch { - case strings.HasPrefix(strings.ToLower(g.URI), "http://"), strings.HasPrefix(strings.ToLower(g.URI), "https://"): - content, err = lib.GetRemoteURLContent(g.URI) + case strings.HasPrefix(strings.ToLower(m.URI), "http://"), strings.HasPrefix(strings.ToLower(m.URI), "https://"): + content, err = lib.GetRemoteURLContent(m.URI) default: - content, err = os.ReadFile(g.URI) + content, err = os.ReadFile(m.URI) } if err != nil { return nil, err } entries := make(map[string]*lib.Entry, 300) - err = g.generateEntries(content, entries) + err = m.generateEntries(content, entries) if err != nil { return nil, err } if len(entries) == 0 { - return nil, fmt.Errorf("❌ [type %s | action %s] no entry is generated", typeMaxmindMMDBIn, g.Action) + return nil, fmt.Errorf("❌ [type %s | action %s] no entry is generated", typeMaxmindMMDBIn, m.Action) } var ignoreIPType lib.IgnoreIPOption - switch g.OnlyIPType { + switch m.OnlyIPType { case lib.IPv4: ignoreIPType = lib.IgnoreIPv6 case lib.IPv6: ignoreIPType = lib.IgnoreIPv4 } - // Filter want list - wantList := make(map[string]bool) - for _, want := range g.Want { - if want = strings.ToUpper(strings.TrimSpace(want)); want != "" { - wantList[want] = true - } - } - for _, entry := range entries { - name := entry.GetName() - if len(wantList) > 0 && !wantList[name] { - continue - } - - switch g.Action { + switch m.Action { case lib.ActionAdd: if err := container.Add(entry, ignoreIPType); err != nil { return nil, err @@ -139,7 +134,7 @@ func (g *maxmindMMDBIn) Input(container lib.Container) (lib.Container, error) { return container, nil } -func (g *maxmindMMDBIn) generateEntries(content []byte, entries map[string]*lib.Entry) error { +func (m *maxmindMMDBIn) generateEntries(content []byte, entries map[string]*lib.Entry) error { db, err := maxminddb.FromBytes(content) if err != nil { return err @@ -177,6 +172,10 @@ func (g *maxmindMMDBIn) generateEntries(content []byte, entries map[string]*lib. continue } + if len(m.Want) > 0 && !m.Want[name] { + continue + } + entry, found := entries[name] if !found { entry = lib.NewEntry(name) diff --git a/plugin/plaintext/text_out.go b/plugin/plaintext/text_out.go index a7125249e..91f959085 100644 --- a/plugin/plaintext/text_out.go +++ b/plugin/plaintext/text_out.go @@ -47,12 +47,20 @@ func newTextOut(action lib.Action, data json.RawMessage) (lib.OutputConverter, e tmp.OutputDir = defaultOutputDir } + // Filter want list + wantList := make([]string, 0, len(tmp.Want)) + for _, want := range tmp.Want { + if want = strings.ToUpper(strings.TrimSpace(want)); want != "" { + wantList = append(wantList, want) + } + } + return &textOut{ Type: typeTextOut, Action: action, Description: descTextOut, OutputDir: tmp.OutputDir, - Want: tmp.Want, + Want: wantList, OnlyIPType: tmp.OnlyIPType, }, nil } @@ -79,15 +87,7 @@ func (t *textOut) GetDescription() string { } func (t *textOut) Output(container lib.Container) error { - // Filter want list - wantList := make([]string, 0, len(t.Want)) - for _, want := range t.Want { - if want = strings.ToUpper(strings.TrimSpace(want)); want != "" { - wantList = append(wantList, want) - } - } - - switch len(wantList) { + switch len(t.Want) { case 0: list := make([]string, 0, 300) for entry := range container.Loop() { @@ -115,9 +115,9 @@ func (t *textOut) Output(container lib.Container) error { default: // Sort the list - slices.Sort(wantList) + slices.Sort(t.Want) - for _, name := range wantList { + for _, name := range t.Want { entry, found := container.GetEntry(name) if !found { log.Printf("❌ entry %s not found", name) diff --git a/plugin/special/cutter.go b/plugin/special/cutter.go index a9fce5a1a..1e465d10a 100644 --- a/plugin/special/cutter.go +++ b/plugin/special/cutter.go @@ -38,11 +38,19 @@ func newCutter(action lib.Action, data json.RawMessage) (lib.InputConverter, err return nil, fmt.Errorf("type %s only supports `remove` action", typeCutter) } + // Filter want list + wantList := make(map[string]bool) + for _, want := range tmp.Want { + if want = strings.ToUpper(strings.TrimSpace(want)); want != "" { + wantList[want] = true + } + } + return &cutter{ Type: typeCutter, Action: action, Description: descCutter, - Want: tmp.Want, + Want: wantList, OnlyIPType: tmp.OnlyIPType, }, nil } @@ -51,7 +59,7 @@ type cutter struct { Type string Action lib.Action Description string - Want []string + Want map[string]bool OnlyIPType lib.IPType } @@ -76,17 +84,8 @@ func (c *cutter) Input(container lib.Container) (lib.Container, error) { ignoreIPType = lib.IgnoreIPv4 } - // Filter want list - wantList := make(map[string]bool) - for _, want := range c.Want { - if want = strings.ToUpper(strings.TrimSpace(want)); want != "" { - wantList[want] = true - } - } - for entry := range container.Loop() { - name := entry.GetName() - if len(wantList) > 0 && !wantList[name] { + if len(c.Want) > 0 && !c.Want[entry.GetName()] { continue } if err := container.Remove(entry, lib.CaseRemoveEntry, ignoreIPType); err != nil { diff --git a/plugin/v2ray/dat_in.go b/plugin/v2ray/dat_in.go index f0b0b0f6f..ddf9c9ec7 100644 --- a/plugin/v2ray/dat_in.go +++ b/plugin/v2ray/dat_in.go @@ -45,12 +45,20 @@ func newGeoIPDatIn(action lib.Action, data json.RawMessage) (lib.InputConverter, return nil, fmt.Errorf("[type %s | action %s] uri must be specified in config", typeGeoIPdatIn, action) } + // Filter want list + wantList := make(map[string]bool) + for _, want := range tmp.Want { + if want = strings.ToUpper(strings.TrimSpace(want)); want != "" { + wantList[want] = true + } + } + return &geoIPDatIn{ Type: typeGeoIPdatIn, Action: action, Description: descGeoIPdatIn, URI: tmp.URI, - Want: tmp.Want, + Want: wantList, OnlyIPType: tmp.OnlyIPType, }, nil } @@ -60,7 +68,7 @@ type geoIPDatIn struct { Action lib.Action Description string URI string - Want []string + Want map[string]bool OnlyIPType lib.IPType } @@ -103,20 +111,7 @@ func (g *geoIPDatIn) Input(container lib.Container) (lib.Container, error) { ignoreIPType = lib.IgnoreIPv4 } - // Filter want list - wantList := make(map[string]bool) - for _, want := range g.Want { - if want = strings.ToUpper(strings.TrimSpace(want)); want != "" { - wantList[want] = true - } - } - for _, entry := range entries { - name := entry.GetName() - if len(wantList) > 0 && !wantList[name] { - continue - } - switch g.Action { case lib.ActionAdd: if err := container.Add(entry, ignoreIPType); err != nil { @@ -178,7 +173,12 @@ func (g *geoIPDatIn) generateEntries(reader io.Reader, entries map[string]*lib.E } for _, geoip := range geoipList.Entry { - name := geoip.CountryCode + name := strings.ToUpper(strings.TrimSpace(geoip.CountryCode)) + + if len(g.Want) > 0 && !g.Want[name] { + continue + } + entry, found := entries[name] if !found { entry = lib.NewEntry(name) diff --git a/plugin/v2ray/dat_out.go b/plugin/v2ray/dat_out.go index da921b03d..e7a14079d 100644 --- a/plugin/v2ray/dat_out.go +++ b/plugin/v2ray/dat_out.go @@ -57,13 +57,21 @@ func newGeoIPDat(action lib.Action, data json.RawMessage) (lib.OutputConverter, tmp.OutputDir = defaultOutputDir } + // Filter want list + wantList := make([]string, 0, len(tmp.Want)) + for _, want := range tmp.Want { + if want = strings.ToUpper(strings.TrimSpace(want)); want != "" { + wantList = append(wantList, want) + } + } + return &geoIPDatOut{ Type: typeGeoIPdatOut, Action: action, Description: descGeoIPdatOut, OutputName: tmp.OutputName, OutputDir: tmp.OutputDir, - Want: tmp.Want, + Want: wantList, OneFilePerList: tmp.OneFilePerList, OnlyIPType: tmp.OnlyIPType, }, nil @@ -93,19 +101,11 @@ func (g *geoIPDatOut) GetDescription() string { } func (g *geoIPDatOut) Output(container lib.Container) error { - // Filter want list - wantList := make([]string, 0, len(g.Want)) - for _, want := range g.Want { - if want = strings.ToUpper(strings.TrimSpace(want)); want != "" { - wantList = append(wantList, want) - } - } - geoIPList := new(router.GeoIPList) geoIPList.Entry = make([]*router.GeoIP, 0, 300) updated := false - switch len(wantList) { + switch len(g.Want) { case 0: list := make([]string, 0, 300) for entry := range container.Loop() { @@ -143,9 +143,9 @@ func (g *geoIPDatOut) Output(container lib.Container) error { default: // Sort the list - sort.Strings(wantList) + sort.Strings(g.Want) - for _, name := range wantList { + for _, name := range g.Want { entry, found := container.GetEntry(name) if !found { log.Printf("❌ entry %s not found", name)