Skip to content

Commit

Permalink
add port status caching to http
Browse files Browse the repository at this point in the history
  • Loading branch information
tarunKoyalwar committed May 8, 2024
1 parent 49ee80c commit 1818f8a
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 13 deletions.
37 changes: 25 additions & 12 deletions pkg/protocols/common/ports/ports.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,8 @@ func NewPortsCache(dialer *fastdialer.Dialer, size int) *PortsCache {
return p
}

// Do performs a check for open ports
func (p *PortsCache) Do(ctx context.Context, input *contextargs.Context) error {
address := input.MetaInput.Address()
if address == "" {
// assume port is open is given info is not present/enough
return nil
}
// Do performs a check for open ports and caches the result
func (p *PortsCache) Do(ctx context.Context, address string) error {
// check if it exists in cache
if value, err := p.cache.GetIFPresent(address); !errors.Is(err, gcache.KeyNotFoundError) {
switch value {
Expand All @@ -65,11 +60,11 @@ func (p *PortsCache) Do(ctx context.Context, input *contextargs.Context) error {
code, _, _ := p.group.Do(address, func() (interface{}, error) {
conn, err := p.dialer.Dial(ctx, "tcp", address)
if err != nil {
p.cache.Set(address, Closed)
_ = p.cache.Set(address, Closed)
return Closed, nil
}
_ = conn.Close()
p.cache.Set(address, Open)
_ = p.cache.Set(address, Open)
return Open, nil
})

Expand All @@ -81,6 +76,16 @@ func (p *PortsCache) Do(ctx context.Context, input *contextargs.Context) error {
return nil
}

// Do performs a check for open ports
func (p *PortsCache) DoInput(ctx context.Context, input *contextargs.Context) error {
address := input.MetaInput.Address()
if address == "" {
// assume port is open is given info is not present/enough
return nil
}
return p.Do(ctx, address)
}

// Close closes the ports cache and releases any allocated resources
func (p *PortsCache) Close() {
p.cache = nil
Expand All @@ -99,10 +104,18 @@ func Close() {
}
}

// IsPortOpen checks if a port is open or not
func IsPortOpen(input *contextargs.Context) error {
// InputPortStatus checks for cached status of input port
func InputPortStatus(input *contextargs.Context) error {
if portsCacher == nil {
return nil
}
return portsCacher.DoInput(input.Context(), input)
}

// CheckPortStatus checks for cached status of remote port
func CheckPortStatus(ctx context.Context, address string) error {
if portsCacher == nil {
return nil
}
return portsCacher.Do(input.Context(), input)
return portsCacher.Do(ctx, address)
}
10 changes: 10 additions & 0 deletions pkg/protocols/common/protocolstate/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/projectdiscovery/fastdialer/fastdialer"
"github.com/projectdiscovery/mapcidr/asn"
"github.com/projectdiscovery/networkpolicy"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/ports"
"github.com/projectdiscovery/nuclei/v3/pkg/types"
"github.com/projectdiscovery/nuclei/v3/pkg/utils/expand"
)
Expand Down Expand Up @@ -143,6 +144,15 @@ func Init(options *types.Options) error {
}
Dialer = dialer

// size of the ports cache
portsCacheSize := 1000
if options.BulkSize < 5000 && options.BulkSize > 1000 && options.ScanStrategy != "host-spray" {
// 5000 is acceptable for host-spray
portsCacheSize = options.BulkSize
}
// add dialer to ports-cache
ports.Init(Dialer, portsCacheSize)

// override dialer in mysql
mysql.RegisterDialContext("tcp", func(ctx context.Context, addr string) (net.Conn, error) {
return Dialer.Dial(ctx, "tcp", addr)
Expand Down
13 changes: 12 additions & 1 deletion pkg/protocols/http/httpclientpool/clientpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (

"github.com/projectdiscovery/fastdialer/fastdialer"
"github.com/projectdiscovery/fastdialer/fastdialer/ja3/impersonate"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/ports"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/utils"
"github.com/projectdiscovery/nuclei/v3/pkg/types"
Expand Down Expand Up @@ -254,8 +255,18 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl

transport := &http.Transport{
ForceAttemptHTTP2: options.ForceAttemptHTTP2,
DialContext: Dialer.Dial,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
// use ports-cache
if err := ports.CheckPortStatus(ctx, addr); err != nil {
return nil, errors.Wrapf(err, "failed to connect : %v", addr)
}
return Dialer.Dial(ctx, network, addr)
},
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
// use ports-cache
if err := ports.CheckPortStatus(ctx, addr); err != nil {
return nil, errors.Wrapf(err, "failed to connect : %v", addr)
}
if options.TlsImpersonate {
return Dialer.DialTLSWithConfigImpersonate(ctx, network, addr, tlsConfig, impersonate.Random, nil)
}
Expand Down

0 comments on commit 1818f8a

Please sign in to comment.