diff --git a/pkg/hosts/hosts.go b/pkg/hosts/hosts.go index 7f05598d7..061314152 100644 --- a/pkg/hosts/hosts.go +++ b/pkg/hosts/hosts.go @@ -40,6 +40,10 @@ func NewHosts(m domain.Matcher[*IPs]) *Hosts { } } +func (h *Hosts) GetMatcher() domain.Matcher[*IPs] { + return h.matcher +} + func (h *Hosts) Lookup(fqdn string) (ipv4, ipv6 []netip.Addr) { ips, ok := h.matcher.Match(fqdn) if !ok { diff --git a/plugin/executable/hosts/hosts.go b/plugin/executable/hosts/hosts.go index de405bb86..7cf9a0061 100644 --- a/plugin/executable/hosts/hosts.go +++ b/plugin/executable/hosts/hosts.go @@ -22,13 +22,17 @@ package hosts import ( "bytes" "context" + "encoding/json" "fmt" "github.com/IrineSistiana/mosdns/v5/coremain" "github.com/IrineSistiana/mosdns/v5/pkg/hosts" "github.com/IrineSistiana/mosdns/v5/pkg/matcher/domain" "github.com/IrineSistiana/mosdns/v5/pkg/query_context" "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "github.com/go-chi/chi/v5" "github.com/miekg/dns" + "go.uber.org/zap" + "net/http" "os" ) @@ -49,8 +53,10 @@ type Hosts struct { h *hosts.Hosts } -func Init(_ *coremain.BP, args any) (any, error) { - return NewHosts(args.(*Args)) +func Init(bp *coremain.BP, args any) (any, error) { + h, err := NewHosts(args.(*Args)) + bp.RegAPI(h.Api(bp.L())) + return h, err } func NewHosts(args *Args) (*Hosts, error) { @@ -87,3 +93,35 @@ func (h *Hosts) Exec(_ context.Context, qCtx *query_context.Context) error { } return nil } + +func (h *Hosts) Api(logger *zap.Logger) *chi.Mux { + router := chi.NewRouter() + router.Post("/update", func(writer http.ResponseWriter, request *http.Request) { + b := request.Body + payload := map[string]interface{}{} + payload["code"] = -1 + if err := domain.LoadFromTextReader[(*hosts.IPs)](h.h.GetMatcher().(*domain.MixMatcher[*hosts.IPs]), b, hosts.ParseIPs); err != nil { + payload["msg"] = err.Error() + if err := respondWithJSON(writer, http.StatusOK, payload); err != nil { + logger.Error("fail to response hosts update", zap.Error(err)) + } + return + } + payload["msg"] = "ok" + payload["code"] = 0 + if err := respondWithJSON(writer, http.StatusOK, payload); err != nil { + logger.Error("fail to response hosts update", zap.Error(err)) + } + }) + return router +} + +func respondWithJSON(w http.ResponseWriter, code int, payload interface{}) error { + response, _ := json.Marshal(payload) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + if _, err := w.Write(response); err != nil { + return err + } + return nil +} diff --git a/plugin/executable/ipset/ipset.go b/plugin/executable/ipset/ipset.go index f8d609b57..a6112744a 100644 --- a/plugin/executable/ipset/ipset.go +++ b/plugin/executable/ipset/ipset.go @@ -35,8 +35,10 @@ func init() { type Args struct { SetName4 string `yaml:"set_name4"` SetName6 string `yaml:"set_name6"` - Mask4 int `yaml:"mask4"` // default 24 - Mask6 int `yaml:"mask6"` // default 32 + Mask4 int `yaml:"mask4"` // default 24 + Mask6 int `yaml:"mask6"` // default 32 + Timeout4 int `yaml:"timeout4"` // default -1, not use + Timeout6 int `yaml:"timeout6"` // default -1, not use } var _ sequence.Executable = (*ipSetPlugin)(nil) @@ -52,21 +54,30 @@ func QuickSetup(_ sequence.BQ, s string) (any, error) { args := new(Args) for _, argsStr := range fs { ss := strings.Split(argsStr, ",") - if len(ss) != 3 { - return nil, fmt.Errorf("invalid args, expect 5 fields, got %d", len(ss)) + if len(ss) != 3 && len(ss) != 4 { + return nil, fmt.Errorf("invalid args, expect 3 or 4 fields, got %d", len(ss)) } m, err := strconv.Atoi(ss[2]) if err != nil { return nil, fmt.Errorf("invalid mask, %w", err) } + ttl := -1 + if len(ss) == 4 { + if ttl, err = strconv.Atoi(ss[3]); err != nil { + return nil, fmt.Errorf("invalid timeout, %w", err) + } + } + switch ss[1] { case "inet": args.Mask4 = m args.SetName4 = ss[0] + args.Timeout4 = ttl case "inet6": args.Mask6 = m args.SetName6 = ss[0] + args.Timeout6 = ttl default: return nil, fmt.Errorf("invalid set family, %s", ss[0]) } diff --git a/plugin/executable/ipset/ipset_linux.go b/plugin/executable/ipset/ipset_linux.go index 18812d98d..bad52101b 100644 --- a/plugin/executable/ipset/ipset_linux.go +++ b/plugin/executable/ipset/ipset_linux.go @@ -54,6 +54,13 @@ func newIpSetPlugin(args *Args) (*ipSetPlugin, error) { }, nil } +func addIpSet(nl *ipset.NetLink, setName string, prefix netip.Prefix, timeout int) error { + if timeout == -1 { + return ipset.AddPrefix(nl, setName, prefix) + } + return ipset.AddPrefix(nl, setName, prefix, ipset.OptTimeout(uint32(timeout))) +} + func (p *ipSetPlugin) Exec(_ context.Context, qCtx *query_context.Context) error { r := qCtx.R() if r != nil { @@ -79,7 +86,7 @@ func (p *ipSetPlugin) addIPSet(r *dns.Msg) error { if !ok { return fmt.Errorf("invalid A record with ip: %s", rr.A) } - if err := ipset.AddPrefix(p.nl, p.args.SetName4, netip.PrefixFrom(addr, p.args.Mask4)); err != nil { + if err := addIpSet(p.nl, p.args.SetName4, netip.PrefixFrom(addr, p.args.Mask4), p.args.Timeout4); err != nil { return err } @@ -91,7 +98,7 @@ func (p *ipSetPlugin) addIPSet(r *dns.Msg) error { if !ok { return fmt.Errorf("invalid AAAA record with ip: %s", rr.AAAA) } - if err := ipset.AddPrefix(p.nl, p.args.SetName6, netip.PrefixFrom(addr, p.args.Mask6)); err != nil { + if err := addIpSet(p.nl, p.args.SetName6, netip.PrefixFrom(addr, p.args.Mask6), p.args.Timeout6); err != nil { return err } default: