From 32dc73ff719fb4ab1dda4070b6f86ea61ba56026 Mon Sep 17 00:00:00 2001
From: Loyalsoldier <10487845+Loyalsoldier@users.noreply.github.com>
Date: Thu, 15 Aug 2024 09:53:46 +0800
Subject: [PATCH] Refine: wantedList in various formats

---
 plugin/maxmind/country_csv.go | 32 ++++++++++---------
 plugin/maxmind/mmdb_in.go     | 59 +++++++++++++++++------------------
 plugin/plaintext/text_out.go  | 24 +++++++-------
 plugin/special/cutter.go      | 23 +++++++-------
 plugin/v2ray/dat_in.go        | 32 +++++++++----------
 plugin/v2ray/dat_out.go       | 24 +++++++-------
 6 files changed, 97 insertions(+), 97 deletions(-)

diff --git a/plugin/maxmind/country_csv.go b/plugin/maxmind/country_csv.go
index e8f4781fb..cf7219997 100644
--- a/plugin/maxmind/country_csv.go
+++ b/plugin/maxmind/country_csv.go
@@ -59,6 +59,14 @@ func newGeoLite2CountryCSV(action lib.Action, data json.RawMessage) (lib.InputCo
 		tmp.IPv6File = defaultIPv6File
 	}
 
+	// Filter want list
+	wantList := make(map[string]bool)
+	for _, want := range tmp.Want {
+		if want = strings.ToUpper(strings.TrimSpace(want)); want != "" {
+			wantList[want] = true
+		}
+	}
+
 	return &geoLite2CountryCSV{
 		Type:            typeCountryCSV,
 		Action:          action,
@@ -66,7 +74,7 @@ func newGeoLite2CountryCSV(action lib.Action, data json.RawMessage) (lib.InputCo
 		CountryCodeFile: tmp.CountryCodeFile,
 		IPv4File:        tmp.IPv4File,
 		IPv6File:        tmp.IPv6File,
-		Want:            tmp.Want,
+		Want:            wantList,
 		OnlyIPType:      tmp.OnlyIPType,
 	}, nil
 }
@@ -78,7 +86,7 @@ type geoLite2CountryCSV struct {
 	CountryCodeFile string
 	IPv4File        string
 	IPv6File        string
-	Want            []string
+	Want            map[string]bool
 	OnlyIPType      lib.IPType
 }
 
@@ -171,11 +179,16 @@ func (g *geoLite2CountryCSV) getCountryCode() (map[string]string, error) {
 		}
 
 		id := strings.TrimSpace(line[0])
-		countryCode := strings.TrimSpace(line[4])
+		countryCode := strings.ToUpper(strings.TrimSpace(line[4]))
 		if id == "" || countryCode == "" {
 			continue
 		}
-		ccMap[id] = strings.ToUpper(countryCode)
+
+		if len(g.Want) > 0 && !g.Want[countryCode] {
+			continue
+		}
+
+		ccMap[id] = countryCode
 	}
 
 	if len(ccMap) == 0 {
@@ -206,14 +219,6 @@ func (g *geoLite2CountryCSV) process(file string, ccMap map[string]string, entri
 	}
 	defer f.Close()
 
-	// Filter want list
-	wantList := make(map[string]bool)
-	for _, want := range g.Want {
-		if want = strings.ToUpper(strings.TrimSpace(want)); want != "" {
-			wantList[want] = true
-		}
-	}
-
 	reader := csv.NewReader(f)
 	reader.Read() // skip header
 
@@ -243,9 +248,6 @@ func (g *geoLite2CountryCSV) process(file string, ccMap map[string]string, entri
 		}
 
 		if countryCode, found := ccMap[ccID]; found {
-			if len(wantList) > 0 && !wantList[countryCode] {
-				continue
-			}
 			cidrStr := strings.ToLower(strings.TrimSpace(record[0]))
 			entry, found := entries[countryCode]
 			if !found {
diff --git a/plugin/maxmind/mmdb_in.go b/plugin/maxmind/mmdb_in.go
index 3c25931de..8789e0724 100644
--- a/plugin/maxmind/mmdb_in.go
+++ b/plugin/maxmind/mmdb_in.go
@@ -46,12 +46,20 @@ func newMaxmindMMDBIn(action lib.Action, data json.RawMessage) (lib.InputConvert
 		tmp.URI = defaultMMDBFile
 	}
 
+	// Filter want list
+	wantList := make(map[string]bool)
+	for _, want := range tmp.Want {
+		if want = strings.ToUpper(strings.TrimSpace(want)); want != "" {
+			wantList[want] = true
+		}
+	}
+
 	return &maxmindMMDBIn{
 		Type:        typeMaxmindMMDBIn,
 		Action:      action,
 		Description: descMaxmindMMDBIn,
 		URI:         tmp.URI,
-		Want:        tmp.Want,
+		Want:        wantList,
 		OnlyIPType:  tmp.OnlyIPType,
 	}, nil
 }
@@ -61,68 +69,55 @@ type maxmindMMDBIn struct {
 	Action      lib.Action
 	Description string
 	URI         string
-	Want        []string
+	Want        map[string]bool
 	OnlyIPType  lib.IPType
 }
 
-func (g *maxmindMMDBIn) GetType() string {
-	return g.Type
+func (m *maxmindMMDBIn) GetType() string {
+	return m.Type
 }
 
-func (g *maxmindMMDBIn) GetAction() lib.Action {
-	return g.Action
+func (m *maxmindMMDBIn) GetAction() lib.Action {
+	return m.Action
 }
 
-func (g *maxmindMMDBIn) GetDescription() string {
-	return g.Description
+func (m *maxmindMMDBIn) GetDescription() string {
+	return m.Description
 }
 
-func (g *maxmindMMDBIn) Input(container lib.Container) (lib.Container, error) {
+func (m *maxmindMMDBIn) Input(container lib.Container) (lib.Container, error) {
 	var content []byte
 	var err error
 	switch {
-	case strings.HasPrefix(strings.ToLower(g.URI), "http://"), strings.HasPrefix(strings.ToLower(g.URI), "https://"):
-		content, err = lib.GetRemoteURLContent(g.URI)
+	case strings.HasPrefix(strings.ToLower(m.URI), "http://"), strings.HasPrefix(strings.ToLower(m.URI), "https://"):
+		content, err = lib.GetRemoteURLContent(m.URI)
 	default:
-		content, err = os.ReadFile(g.URI)
+		content, err = os.ReadFile(m.URI)
 	}
 	if err != nil {
 		return nil, err
 	}
 
 	entries := make(map[string]*lib.Entry, 300)
-	err = g.generateEntries(content, entries)
+	err = m.generateEntries(content, entries)
 	if err != nil {
 		return nil, err
 	}
 
 	if len(entries) == 0 {
-		return nil, fmt.Errorf("❌ [type %s | action %s] no entry is generated", typeMaxmindMMDBIn, g.Action)
+		return nil, fmt.Errorf("❌ [type %s | action %s] no entry is generated", typeMaxmindMMDBIn, m.Action)
 	}
 
 	var ignoreIPType lib.IgnoreIPOption
-	switch g.OnlyIPType {
+	switch m.OnlyIPType {
 	case lib.IPv4:
 		ignoreIPType = lib.IgnoreIPv6
 	case lib.IPv6:
 		ignoreIPType = lib.IgnoreIPv4
 	}
 
-	// Filter want list
-	wantList := make(map[string]bool)
-	for _, want := range g.Want {
-		if want = strings.ToUpper(strings.TrimSpace(want)); want != "" {
-			wantList[want] = true
-		}
-	}
-
 	for _, entry := range entries {
-		name := entry.GetName()
-		if len(wantList) > 0 && !wantList[name] {
-			continue
-		}
-
-		switch g.Action {
+		switch m.Action {
 		case lib.ActionAdd:
 			if err := container.Add(entry, ignoreIPType); err != nil {
 				return nil, err
@@ -139,7 +134,7 @@ func (g *maxmindMMDBIn) Input(container lib.Container) (lib.Container, error) {
 	return container, nil
 }
 
-func (g *maxmindMMDBIn) generateEntries(content []byte, entries map[string]*lib.Entry) error {
+func (m *maxmindMMDBIn) generateEntries(content []byte, entries map[string]*lib.Entry) error {
 	db, err := maxminddb.FromBytes(content)
 	if err != nil {
 		return err
@@ -177,6 +172,10 @@ func (g *maxmindMMDBIn) generateEntries(content []byte, entries map[string]*lib.
 			continue
 		}
 
+		if len(m.Want) > 0 && !m.Want[name] {
+			continue
+		}
+
 		entry, found := entries[name]
 		if !found {
 			entry = lib.NewEntry(name)
diff --git a/plugin/plaintext/text_out.go b/plugin/plaintext/text_out.go
index a7125249e..91f959085 100644
--- a/plugin/plaintext/text_out.go
+++ b/plugin/plaintext/text_out.go
@@ -47,12 +47,20 @@ func newTextOut(action lib.Action, data json.RawMessage) (lib.OutputConverter, e
 		tmp.OutputDir = defaultOutputDir
 	}
 
+	// Filter want list
+	wantList := make([]string, 0, len(tmp.Want))
+	for _, want := range tmp.Want {
+		if want = strings.ToUpper(strings.TrimSpace(want)); want != "" {
+			wantList = append(wantList, want)
+		}
+	}
+
 	return &textOut{
 		Type:        typeTextOut,
 		Action:      action,
 		Description: descTextOut,
 		OutputDir:   tmp.OutputDir,
-		Want:        tmp.Want,
+		Want:        wantList,
 		OnlyIPType:  tmp.OnlyIPType,
 	}, nil
 }
@@ -79,15 +87,7 @@ func (t *textOut) GetDescription() string {
 }
 
 func (t *textOut) Output(container lib.Container) error {
-	// Filter want list
-	wantList := make([]string, 0, len(t.Want))
-	for _, want := range t.Want {
-		if want = strings.ToUpper(strings.TrimSpace(want)); want != "" {
-			wantList = append(wantList, want)
-		}
-	}
-
-	switch len(wantList) {
+	switch len(t.Want) {
 	case 0:
 		list := make([]string, 0, 300)
 		for entry := range container.Loop() {
@@ -115,9 +115,9 @@ func (t *textOut) Output(container lib.Container) error {
 
 	default:
 		// Sort the list
-		slices.Sort(wantList)
+		slices.Sort(t.Want)
 
-		for _, name := range wantList {
+		for _, name := range t.Want {
 			entry, found := container.GetEntry(name)
 			if !found {
 				log.Printf("❌ entry %s not found", name)
diff --git a/plugin/special/cutter.go b/plugin/special/cutter.go
index a9fce5a1a..1e465d10a 100644
--- a/plugin/special/cutter.go
+++ b/plugin/special/cutter.go
@@ -38,11 +38,19 @@ func newCutter(action lib.Action, data json.RawMessage) (lib.InputConverter, err
 		return nil, fmt.Errorf("type %s only supports `remove` action", typeCutter)
 	}
 
+	// Filter want list
+	wantList := make(map[string]bool)
+	for _, want := range tmp.Want {
+		if want = strings.ToUpper(strings.TrimSpace(want)); want != "" {
+			wantList[want] = true
+		}
+	}
+
 	return &cutter{
 		Type:        typeCutter,
 		Action:      action,
 		Description: descCutter,
-		Want:        tmp.Want,
+		Want:        wantList,
 		OnlyIPType:  tmp.OnlyIPType,
 	}, nil
 }
@@ -51,7 +59,7 @@ type cutter struct {
 	Type        string
 	Action      lib.Action
 	Description string
-	Want        []string
+	Want        map[string]bool
 	OnlyIPType  lib.IPType
 }
 
@@ -76,17 +84,8 @@ func (c *cutter) Input(container lib.Container) (lib.Container, error) {
 		ignoreIPType = lib.IgnoreIPv4
 	}
 
-	// Filter want list
-	wantList := make(map[string]bool)
-	for _, want := range c.Want {
-		if want = strings.ToUpper(strings.TrimSpace(want)); want != "" {
-			wantList[want] = true
-		}
-	}
-
 	for entry := range container.Loop() {
-		name := entry.GetName()
-		if len(wantList) > 0 && !wantList[name] {
+		if len(c.Want) > 0 && !c.Want[entry.GetName()] {
 			continue
 		}
 		if err := container.Remove(entry, lib.CaseRemoveEntry, ignoreIPType); err != nil {
diff --git a/plugin/v2ray/dat_in.go b/plugin/v2ray/dat_in.go
index f0b0b0f6f..ddf9c9ec7 100644
--- a/plugin/v2ray/dat_in.go
+++ b/plugin/v2ray/dat_in.go
@@ -45,12 +45,20 @@ func newGeoIPDatIn(action lib.Action, data json.RawMessage) (lib.InputConverter,
 		return nil, fmt.Errorf("[type %s | action %s] uri must be specified in config", typeGeoIPdatIn, action)
 	}
 
+	// Filter want list
+	wantList := make(map[string]bool)
+	for _, want := range tmp.Want {
+		if want = strings.ToUpper(strings.TrimSpace(want)); want != "" {
+			wantList[want] = true
+		}
+	}
+
 	return &geoIPDatIn{
 		Type:        typeGeoIPdatIn,
 		Action:      action,
 		Description: descGeoIPdatIn,
 		URI:         tmp.URI,
-		Want:        tmp.Want,
+		Want:        wantList,
 		OnlyIPType:  tmp.OnlyIPType,
 	}, nil
 }
@@ -60,7 +68,7 @@ type geoIPDatIn struct {
 	Action      lib.Action
 	Description string
 	URI         string
-	Want        []string
+	Want        map[string]bool
 	OnlyIPType  lib.IPType
 }
 
@@ -103,20 +111,7 @@ func (g *geoIPDatIn) Input(container lib.Container) (lib.Container, error) {
 		ignoreIPType = lib.IgnoreIPv4
 	}
 
-	// Filter want list
-	wantList := make(map[string]bool)
-	for _, want := range g.Want {
-		if want = strings.ToUpper(strings.TrimSpace(want)); want != "" {
-			wantList[want] = true
-		}
-	}
-
 	for _, entry := range entries {
-		name := entry.GetName()
-		if len(wantList) > 0 && !wantList[name] {
-			continue
-		}
-
 		switch g.Action {
 		case lib.ActionAdd:
 			if err := container.Add(entry, ignoreIPType); err != nil {
@@ -178,7 +173,12 @@ func (g *geoIPDatIn) generateEntries(reader io.Reader, entries map[string]*lib.E
 	}
 
 	for _, geoip := range geoipList.Entry {
-		name := geoip.CountryCode
+		name := strings.ToUpper(strings.TrimSpace(geoip.CountryCode))
+
+		if len(g.Want) > 0 && !g.Want[name] {
+			continue
+		}
+
 		entry, found := entries[name]
 		if !found {
 			entry = lib.NewEntry(name)
diff --git a/plugin/v2ray/dat_out.go b/plugin/v2ray/dat_out.go
index da921b03d..e7a14079d 100644
--- a/plugin/v2ray/dat_out.go
+++ b/plugin/v2ray/dat_out.go
@@ -57,13 +57,21 @@ func newGeoIPDat(action lib.Action, data json.RawMessage) (lib.OutputConverter,
 		tmp.OutputDir = defaultOutputDir
 	}
 
+	// Filter want list
+	wantList := make([]string, 0, len(tmp.Want))
+	for _, want := range tmp.Want {
+		if want = strings.ToUpper(strings.TrimSpace(want)); want != "" {
+			wantList = append(wantList, want)
+		}
+	}
+
 	return &geoIPDatOut{
 		Type:           typeGeoIPdatOut,
 		Action:         action,
 		Description:    descGeoIPdatOut,
 		OutputName:     tmp.OutputName,
 		OutputDir:      tmp.OutputDir,
-		Want:           tmp.Want,
+		Want:           wantList,
 		OneFilePerList: tmp.OneFilePerList,
 		OnlyIPType:     tmp.OnlyIPType,
 	}, nil
@@ -93,19 +101,11 @@ func (g *geoIPDatOut) GetDescription() string {
 }
 
 func (g *geoIPDatOut) Output(container lib.Container) error {
-	// Filter want list
-	wantList := make([]string, 0, len(g.Want))
-	for _, want := range g.Want {
-		if want = strings.ToUpper(strings.TrimSpace(want)); want != "" {
-			wantList = append(wantList, want)
-		}
-	}
-
 	geoIPList := new(router.GeoIPList)
 	geoIPList.Entry = make([]*router.GeoIP, 0, 300)
 	updated := false
 
-	switch len(wantList) {
+	switch len(g.Want) {
 	case 0:
 		list := make([]string, 0, 300)
 		for entry := range container.Loop() {
@@ -143,9 +143,9 @@ func (g *geoIPDatOut) Output(container lib.Container) error {
 
 	default:
 		// Sort the list
-		sort.Strings(wantList)
+		sort.Strings(g.Want)
 
-		for _, name := range wantList {
+		for _, name := range g.Want {
 			entry, found := container.GetEntry(name)
 			if !found {
 				log.Printf("❌ entry %s not found", name)