Skip to content

Commit

Permalink
feat: add port status caching
Browse files Browse the repository at this point in the history
  • Loading branch information
tarunKoyalwar committed May 8, 2024
1 parent 997c5d1 commit 49ee80c
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 0 deletions.
36 changes: 36 additions & 0 deletions pkg/protocols/common/contextargs/metainput.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"crypto/md5"
"fmt"
"net"
"strings"

jsoniter "github.com/json-iterator/go"
Expand Down Expand Up @@ -47,6 +48,41 @@ func (metaInput *MetaInput) URL() (*urlutil.URL, error) {
return instance, nil
}

// Port returns the port of the target
// if port is not present then empty string is returned
func (metaInput *MetaInput) Port() string {
target, err := urlutil.ParseAbsoluteURL(metaInput.Input, false)
if err != nil {
return ""
}
return target.Port()
}

// Address return the remote address of target
// Note: it does not resolve the domain to ip
func (metaInput *MetaInput) Address() string {
target, err := urlutil.ParseAbsoluteURL(metaInput.Input, false)
if err != nil {
return ""
}
host := target.Hostname()
port := target.Port()
if metaInput.CustomIP != "" {
host = metaInput.CustomIP
}
if port == "" {
switch target.Scheme {
case urlutil.HTTP:
port = "80"
case urlutil.HTTPS:
port = "443"
default:
port = "80"
}
}
return net.JoinHostPort(host, port)
}

// ID returns a unique id/hash for metainput
func (metaInput *MetaInput) ID() string {
if metaInput.CustomIP != "" {
Expand Down
108 changes: 108 additions & 0 deletions pkg/protocols/common/ports/ports.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// ports implements a open-port cache to avoid sending redundant requests to same port
package ports

import (
"context"
"errors"

"github.com/Mzack9999/gcache"
"github.com/projectdiscovery/fastdialer/fastdialer"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/contextargs"
singleflight "github.com/projectdiscovery/utils/memoize/simpleflight"
)

var (
portsCacher *PortsCache
ErrPortClosed = errors.New("port closed or filtered")
)

type PortStatus uint8

const (
Unknown PortStatus = iota
Open
Closed
)

// PortsCache is a cache for open ports
type PortsCache struct {
cache gcache.Cache[string, PortStatus]
group singleflight.Group[string]
dialer *fastdialer.Dialer
}

// NewPortsCache creates a new cache for open ports
func NewPortsCache(dialer *fastdialer.Dialer, size int) *PortsCache {
p := &PortsCache{group: singleflight.Group[string]{}, dialer: dialer}
cache := gcache.New[string, PortStatus](size).
LRU().
EvictedFunc(func(key string, value PortStatus) {
p.group.Forget(key)
}).
Build()
p.cache = cache
return p
}

// Do performs a check for open ports
func (p *PortsCache) Do(ctx context.Context, input *contextargs.Context) error {
address := input.MetaInput.Address()
if address == "" {
// assume port is open is given info is not present/enough
return nil
}
// check if it exists in cache
if value, err := p.cache.GetIFPresent(address); !errors.Is(err, gcache.KeyNotFoundError) {
switch value {
case Closed:
return ErrPortClosed
default:
return nil
}
}

// if not in cache then check if it is open
code, _, _ := p.group.Do(address, func() (interface{}, error) {
conn, err := p.dialer.Dial(ctx, "tcp", address)
if err != nil {
p.cache.Set(address, Closed)
return Closed, nil
}
_ = conn.Close()
p.cache.Set(address, Open)
return Open, nil
})

if status, ok := code.(PortStatus); ok {
if status == Closed {
return ErrPortClosed
}
}
return nil
}

// Close closes the ports cache and releases any allocated resources
func (p *PortsCache) Close() {
p.cache = nil
p.group = singleflight.Group[string]{}
}

// Init initializes the ports package
func Init(dialer *fastdialer.Dialer, size int) {
portsCacher = NewPortsCache(dialer, size)
}

// Close closes the ports package
func Close() {
if portsCacher != nil {
portsCacher.Close()
}
}

// IsPortOpen checks if a port is open or not
func IsPortOpen(input *contextargs.Context) error {
if portsCacher == nil {
return nil
}
return portsCacher.Do(input.Context(), input)
}

0 comments on commit 49ee80c

Please sign in to comment.