diff --git a/pkg/protocols/common/contextargs/metainput.go b/pkg/protocols/common/contextargs/metainput.go index afda2fda26..62e3194a51 100644 --- a/pkg/protocols/common/contextargs/metainput.go +++ b/pkg/protocols/common/contextargs/metainput.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/md5" "fmt" + "net" "strings" jsoniter "github.com/json-iterator/go" @@ -47,6 +48,41 @@ func (metaInput *MetaInput) URL() (*urlutil.URL, error) { return instance, nil } +// Port returns the port of the target +// if port is not present then empty string is returned +func (metaInput *MetaInput) Port() string { + target, err := urlutil.ParseAbsoluteURL(metaInput.Input, false) + if err != nil { + return "" + } + return target.Port() +} + +// Address return the remote address of target +// Note: it does not resolve the domain to ip +func (metaInput *MetaInput) Address() string { + target, err := urlutil.ParseAbsoluteURL(metaInput.Input, false) + if err != nil { + return "" + } + host := target.Hostname() + port := target.Port() + if metaInput.CustomIP != "" { + host = metaInput.CustomIP + } + if port == "" { + switch target.Scheme { + case urlutil.HTTP: + port = "80" + case urlutil.HTTPS: + port = "443" + default: + port = "80" + } + } + return net.JoinHostPort(host, port) +} + // ID returns a unique id/hash for metainput func (metaInput *MetaInput) ID() string { if metaInput.CustomIP != "" { diff --git a/pkg/protocols/common/ports/ports.go b/pkg/protocols/common/ports/ports.go new file mode 100644 index 0000000000..3abd5d7900 --- /dev/null +++ b/pkg/protocols/common/ports/ports.go @@ -0,0 +1,108 @@ +// ports implements a open-port cache to avoid sending redundant requests to same port +package ports + +import ( + "context" + "errors" + + "github.com/Mzack9999/gcache" + "github.com/projectdiscovery/fastdialer/fastdialer" + "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/contextargs" + singleflight "github.com/projectdiscovery/utils/memoize/simpleflight" +) + +var ( + portsCacher *PortsCache + ErrPortClosed = errors.New("port closed or filtered") +) + +type PortStatus uint8 + +const ( + Unknown PortStatus = iota + Open + Closed +) + +// PortsCache is a cache for open ports +type PortsCache struct { + cache gcache.Cache[string, PortStatus] + group singleflight.Group[string] + dialer *fastdialer.Dialer +} + +// NewPortsCache creates a new cache for open ports +func NewPortsCache(dialer *fastdialer.Dialer, size int) *PortsCache { + p := &PortsCache{group: singleflight.Group[string]{}, dialer: dialer} + cache := gcache.New[string, PortStatus](size). + LRU(). + EvictedFunc(func(key string, value PortStatus) { + p.group.Forget(key) + }). + Build() + p.cache = cache + return p +} + +// Do performs a check for open ports +func (p *PortsCache) Do(ctx context.Context, input *contextargs.Context) error { + address := input.MetaInput.Address() + if address == "" { + // assume port is open is given info is not present/enough + return nil + } + // check if it exists in cache + if value, err := p.cache.GetIFPresent(address); !errors.Is(err, gcache.KeyNotFoundError) { + switch value { + case Closed: + return ErrPortClosed + default: + return nil + } + } + + // if not in cache then check if it is open + code, _, _ := p.group.Do(address, func() (interface{}, error) { + conn, err := p.dialer.Dial(ctx, "tcp", address) + if err != nil { + p.cache.Set(address, Closed) + return Closed, nil + } + _ = conn.Close() + p.cache.Set(address, Open) + return Open, nil + }) + + if status, ok := code.(PortStatus); ok { + if status == Closed { + return ErrPortClosed + } + } + return nil +} + +// Close closes the ports cache and releases any allocated resources +func (p *PortsCache) Close() { + p.cache = nil + p.group = singleflight.Group[string]{} +} + +// Init initializes the ports package +func Init(dialer *fastdialer.Dialer, size int) { + portsCacher = NewPortsCache(dialer, size) +} + +// Close closes the ports package +func Close() { + if portsCacher != nil { + portsCacher.Close() + } +} + +// IsPortOpen checks if a port is open or not +func IsPortOpen(input *contextargs.Context) error { + if portsCacher == nil { + return nil + } + return portsCacher.Do(input.Context(), input) +}