feat(hosterrorscache): add Remove and MarkFailedOrRemove methods #5984

Open

wants to merge 6 commits into dev
6 changes: 3 additions & 3 deletions pkg/core/workflow_execute.go
@@ -96,10 +96,10 @@ func (e *Engine) runWorkflowStep(template *workflows.WorkflowTemplate, ctx *scan
firstMatched = true
}
}
if w.Options.HostErrorsCache != nil {
w.Options.HostErrorsCache.MarkFailed(w.Options.ProtocolType.String(), ctx.Input, err)
}
if err != nil {
if w.Options.HostErrorsCache != nil {
w.Options.HostErrorsCache.MarkFailed(w.Options.ProtocolType.String(), ctx.Input, err)
}
if len(template.Executers) == 1 {
mainErr = err
} else {
124 changes: 94 additions & 30 deletions pkg/protocols/common/hosterrorscache/hosterrorscache.go
@@ -1,6 +1,7 @@
package hosterrorscache

import (
"errors"
"net"
"net/url"
"regexp"
@@ -20,10 +21,12 @@ import (
// CacheInterface defines the signature of the hosterrorscache so that
// users of Nuclei as embedded lib may implement their own cache
type CacheInterface interface {
SetVerbose(verbose bool) // log verbosely
Close() // close the cache
Check(protoType string, ctx *contextargs.Context) bool // return true if the host should be skipped
MarkFailed(protoType string, ctx *contextargs.Context, err error) // record a failure (and cause) for the host
SetVerbose(verbose bool) // log verbosely
Close() // close the cache
Check(protoType string, ctx *contextargs.Context) bool // return true if the host should be skipped
Remove(ctx *contextargs.Context) // remove a host from the cache
MarkFailed(protoType string, ctx *contextargs.Context, err error) // record a failure (and cause) for the host
MarkFailedOrRemove(protoType string, ctx *contextargs.Context, err error) // record a failure (and cause) for the host or remove it
}

var (
@@ -47,16 +50,20 @@ type cacheItem struct {
errors atomic.Int32
isPermanentErr bool
cause error // optional cause
mu sync.Mutex
}

const DefaultMaxHostsCount = 10000

// New returns a new host max errors cache
func New(maxHostError, maxHostsCount int, trackError []string) *Cache {
gc := gcache.New[string, *cacheItem](maxHostsCount).
ARC().
Build()
return &Cache{failedTargets: gc, MaxHostError: maxHostError, TrackError: trackError}
gc := gcache.New[string, *cacheItem](maxHostsCount).ARC().Build()

return &Cache{
failedTargets: gc,
MaxHostError: maxHostError,
TrackError: trackError,
}
}

// SetVerbose sets the cache to log at verbose level
@@ -118,47 +125,104 @@ func (c *Cache) NormalizeCacheValue(value string) string {
func (c *Cache) Check(protoType string, ctx *contextargs.Context) bool {
finalValue := c.GetKeyFromContext(ctx, nil)

existingCacheItem, err := c.failedTargets.GetIFPresent(finalValue)
cache, err := c.failedTargets.GetIFPresent(finalValue)
if err != nil {
return false
}
if existingCacheItem.isPermanentErr {

cache.mu.Lock()
defer cache.mu.Unlock()

if cache.isPermanentErr {
// skipping permanent errors is expected so verbose instead of info
gologger.Verbose().Msgf("Skipped %s from target list as found unresponsive permanently: %s", finalValue, existingCacheItem.cause)
gologger.Verbose().Msgf("Skipped %s from target list as found unresponsive permanently: %s", finalValue, cache.cause)
return true
}

if existingCacheItem.errors.Load() >= int32(c.MaxHostError) {
existingCacheItem.Do(func() {
gologger.Info().Msgf("Skipped %s from target list as found unresponsive %d times", finalValue, existingCacheItem.errors.Load())
if cache.errors.Load() >= int32(c.MaxHostError) {
cache.Do(func() {
gologger.Info().Msgf("Skipped %s from target list as found unresponsive %d times", finalValue, cache.errors.Load())
})
return true
}

return false
}

// Remove removes a host from the cache
func (c *Cache) Remove(ctx *contextargs.Context) {
key := c.GetKeyFromContext(ctx, nil)
_ = c.failedTargets.Remove(key) // remove the entry even if it is not present in the cache
}

// MarkFailed marks a host as failed previously
//
// Deprecated: Use MarkFailedOrRemove instead.
func (c *Cache) MarkFailed(protoType string, ctx *contextargs.Context, err error) {
if !c.checkError(protoType, err) {
c.MarkFailedOrRemove(protoType, ctx, err)
}

// MarkFailedOrRemove marks a host as failed previously or removes it
func (c *Cache) MarkFailedOrRemove(protoType string, ctx *contextargs.Context, err error) {
if err != nil && !c.checkError(protoType, err) {
return
}
finalValue := c.GetKeyFromContext(ctx, err)
existingCacheItem, err := c.failedTargets.GetIFPresent(finalValue)
if err != nil || existingCacheItem == nil {
newItem := &cacheItem{errors: atomic.Int32{}}
newItem.errors.Store(1)
if errkit.IsKind(err, errkit.ErrKindNetworkPermanent) {
// skip this address altogether
// permanent errors are always permanent hence this is created once
// and never updated so no need to synchronize
newItem.isPermanentErr = true
newItem.cause = err
}
_ = c.failedTargets.Set(finalValue, newItem)

if err == nil {
// Remove the host from cache
//
// NOTE(dwisiswant0): The decision was made to completely remove the
// cached entry for the host instead of simply decrementing the error
// count (using `(atomic.Int32).Swap` to update the value to `N-1`).
// This approach was chosen because the error handling logic operates
// concurrently, and decrementing the count could lead to UB (unexpected
// behavior) even when the error is `nil`.
//
// To clarify, consider the following scenario where the error
// encountered does NOT belong to the permanent network error category
// (`errkit.ErrKindNetworkPermanent`):
//
// 1. Iteration 1: A timeout error occurs, and the error count for the
// host is incremented.
// 2. Iteration 2: Another timeout error is encountered, leading to
// another increment in the host's error count.
// 3. Iteration 3: A third timeout error happens, which increments the
// error count further. At this point, the host is flagged as
// unresponsive.
// 4. Iteration 4: The host becomes reachable (no error or a transient
// issue resolved). Instead of performing a no-op and leaving the
// host in the cache, the host entry is removed entirely to reset its
// state.
// 5. Iteration 5: A subsequent timeout error occurs after the host was
// removed and re-added to the cache. The error count is reset and
// starts from 1 again.
//
// This removal strategy ensures the cache is updated dynamically to
// reflect the current state of the host without persisting stale or
// irrelevant error counts that could interfere with future error
// handling and tracking logic.
c.Remove(ctx)

return
}
existingCacheItem.errors.Add(1)
_ = c.failedTargets.Set(finalValue, existingCacheItem)

cacheKey := c.GetKeyFromContext(ctx, err)
cache, cacheErr := c.failedTargets.GetIFPresent(cacheKey)
if errors.Is(cacheErr, gcache.KeyNotFoundError) {
cache = &cacheItem{errors: atomic.Int32{}}
}

cache.mu.Lock()
defer cache.mu.Unlock()

if errkit.IsKind(err, errkit.ErrKindNetworkPermanent) {
cache.isPermanentErr = true
}

cache.cause = err
cache.errors.Add(1)

_ = c.failedTargets.Set(cacheKey, cache)
}

// GetKeyFromContext returns the key for the cache from the context
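
To make the removal strategy described in the NOTE inside MarkFailedOrRemove concrete, here is a small standalone sketch of the reset-on-success scenario with a threshold of 3. It is not part of this diff; the import paths are assumed for nuclei v3 and the timeout error string is borrowed from the tests below.

package main

import (
	"context"
	"errors"
	"fmt"

	"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/contextargs"
	"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/hosterrorscache"
)

func main() {
	// allow at most 3 tracked errors per host
	cache := hosterrorscache.New(3, hosterrorscache.DefaultMaxHostsCount, nil)
	defer cache.Close()

	ctx := contextargs.NewWithInput(context.Background(), "scanme.sh")
	timeout := errors.New("net/http: timeout awaiting response headers")

	// three tracked failures flag the host as unresponsive
	for i := 0; i < 3; i++ {
		cache.MarkFailedOrRemove("http", ctx, timeout)
	}
	fmt.Println(cache.Check("http", ctx)) // true: host is skipped

	// the host recovers: a nil error removes the cache entry instead of decrementing it
	cache.MarkFailedOrRemove("http", ctx, nil)
	fmt.Println(cache.Check("http", ctx)) // false: the error count starts from scratch
}
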
78 changes: 62 additions & 16 deletions pkg/protocols/common/hosterrorscache/hosterrorscache_test.go
@@ -2,7 +2,7 @@ package hosterrorscache

import (
"context"
"fmt"
"errors"
"sync"
"sync/atomic"
"testing"
@@ -17,28 +17,40 @@ const (

func TestCacheCheck(t *testing.T) {
cache := New(3, DefaultMaxHostsCount, nil)
err := errors.New("net/http: timeout awaiting response headers")

t.Run("increment host error", func(t *testing.T) {
ctx := newCtxArgs(t.Name())
for i := 1; i < 3; i++ {
cache.MarkFailed(protoType, ctx, err)
got := cache.Check(protoType, ctx)
require.Falsef(t, got, "got %v in iteration %d", got, i)
}
})

for i := 0; i < 100; i++ {
cache.MarkFailed(protoType, newCtxArgs("test"), fmt.Errorf("could not resolve host"))
got := cache.Check(protoType, newCtxArgs("test"))
if i < 2 {
// till 3 the host is not flagged to skip
require.False(t, got)
} else {
// above 3 it must remain flagged to skip
require.True(t, got)
t.Run("flagged", func(t *testing.T) {
ctx := newCtxArgs(t.Name())
for i := 1; i <= 3; i++ {
cache.MarkFailed(protoType, ctx, err)
}
}

value := cache.Check(protoType, newCtxArgs("test"))
require.Equal(t, true, value, "could not get checked value")
got := cache.Check(protoType, ctx)
require.True(t, got)
})

t.Run("mark failed or remove", func(t *testing.T) {
ctx := newCtxArgs(t.Name())
cache.MarkFailedOrRemove(protoType, ctx, nil) // nil error should remove the host from cache
got := cache.Check(protoType, ctx)
require.False(t, got)
})
}

func TestTrackErrors(t *testing.T) {
cache := New(3, DefaultMaxHostsCount, []string{"custom error"})

for i := 0; i < 100; i++ {
cache.MarkFailed(protoType, newCtxArgs("custom"), fmt.Errorf("got: nested: custom error"))
cache.MarkFailed(protoType, newCtxArgs("custom"), errors.New("got: nested: custom error"))
got := cache.Check(protoType, newCtxArgs("custom"))
if i < 2 {
// till 3 the host is not flagged to skip
@@ -74,6 +86,20 @@ func TestCacheItemDo(t *testing.T) {
require.Equal(t, count, 1)
}

func TestRemove(t *testing.T) {
cache := New(3, DefaultMaxHostsCount, nil)
ctx := newCtxArgs(t.Name())
err := errors.New("net/http: timeout awaiting response headers")

for i := 0; i < 100; i++ {
cache.MarkFailed(protoType, ctx, err)
}

require.True(t, cache.Check(protoType, ctx))
cache.Remove(ctx)
require.False(t, cache.Check(protoType, ctx))
}

func TestCacheMarkFailed(t *testing.T) {
cache := New(3, DefaultMaxHostsCount, nil)

@@ -90,7 +116,7 @@ func TestCacheMarkFailed(t *testing.T) {

for _, test := range tests {
normalizedCacheValue := cache.GetKeyFromContext(newCtxArgs(test.host), nil)
cache.MarkFailed(protoType, newCtxArgs(test.host), fmt.Errorf("no address found for host"))
cache.MarkFailed(protoType, newCtxArgs(test.host), errors.New("no address found for host"))
failedTarget, err := cache.failedTargets.Get(normalizedCacheValue)
require.Nil(t, err)
require.NotNil(t, failedTarget)
@@ -126,7 +152,7 @@ func TestCacheMarkFailedConcurrent(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
cache.MarkFailed(protoType, newCtxArgs(currentTest.host), fmt.Errorf("could not resolve host"))
cache.MarkFailed(protoType, newCtxArgs(currentTest.host), errors.New("net/http: timeout awaiting response headers"))
}()
}
}
@@ -144,6 +170,26 @@
}
}

func TestCacheCheckConcurrent(t *testing.T) {
cache := New(3, DefaultMaxHostsCount, nil)
ctx := newCtxArgs(t.Name())

wg := sync.WaitGroup{}
for i := 1; i <= 100; i++ {
wg.Add(1)
i := i
go func() {
defer wg.Done()
cache.MarkFailed(protoType, ctx, errors.New("no address found for host"))
if i >= 3 {
got := cache.Check(protoType, ctx)
require.True(t, got)
}
}()
}
wg.Wait()
}

func newCtxArgs(value string) *contextargs.Context {
ctx := contextargs.NewWithInput(context.TODO(), value)
return ctx
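
The interface comment in hosterrorscache.go notes that embedders of Nuclei may supply their own cache, so a custom implementation now has to cover Remove and MarkFailedOrRemove as well. A minimal no-op sketch of what that looks like after this change (package name and import paths are assumptions, not code from this PR):

package customcache

import (
	"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/contextargs"
	"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/hosterrorscache"
)

// noopCache satisfies hosterrorscache.CacheInterface but never skips a host.
type noopCache struct{}

var _ hosterrorscache.CacheInterface = noopCache{}

func (noopCache) SetVerbose(verbose bool) {}
func (noopCache) Close()                  {}

// Check reports whether the host should be skipped; a no-op cache never skips.
func (noopCache) Check(protoType string, ctx *contextargs.Context) bool { return false }

// Remove drops a host from the cache; nothing to drop here.
func (noopCache) Remove(ctx *contextargs.Context) {}

// MarkFailed stays for backward compatibility with the deprecated method.
func (noopCache) MarkFailed(protoType string, ctx *contextargs.Context, err error) {}

// MarkFailedOrRemove records a failure, or clears the host when err is nil.
func (noopCache) MarkFailedOrRemove(protoType string, ctx *contextargs.Context, err error) {}
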
24 changes: 6 additions & 18 deletions pkg/protocols/http/request.go
@@ -149,11 +149,8 @@ func (request *Request) executeRaceRequest(input *contextargs.Context, previous

// look for unresponsive hosts and cancel inflight requests as well
spmHandler.SetOnResultCallback(func(err error) {
if err == nil {
return
}
// marks this host as unresponsive if applicable
request.markUnresponsiveAddress(input, err)
request.markHostError(input, err)
if request.isUnresponsiveAddress(input) {
// stop all inflight requests
spmHandler.Cancel()
@@ -234,11 +231,8 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV

// look for unresponsive hosts and cancel inflight requests as well
spmHandler.SetOnResultCallback(func(err error) {
if err == nil {
return
}
// marks this host as unresponsive if applicable
request.markUnresponsiveAddress(input, err)
request.markHostError(input, err)
if request.isUnresponsiveAddress(input) {
// stop all inflight requests
spmHandler.Cancel()
@@ -378,11 +372,8 @@ func (request *Request) executeTurboHTTP(input *contextargs.Context, dynamicValu

// look for unresponsive hosts and cancel inflight requests as well
spmHandler.SetOnResultCallback(func(err error) {
if err == nil {
return
}
// marks this host as unresponsive if applicable
request.markUnresponsiveAddress(input, err)
request.markHostError(input, err)
if request.isUnresponsiveAddress(input) {
// stop all inflight requests
spmHandler.Cancel()
@@ -551,12 +542,12 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa
}
if execReqErr != nil {
// if applicable mark the host as unresponsive
request.markUnresponsiveAddress(updatedInput, execReqErr)
requestErr = errorutil.NewWithErr(execReqErr).Msgf("got err while executing %v", generatedHttpRequest.URL())
request.options.Progress.IncrementFailedRequestsBy(1)
} else {
request.options.Progress.IncrementRequests()
}
request.markHostError(updatedInput, execReqErr)

// If this was a match, and we want to stop at first match, skip all further requests.
shouldStopAtFirstMatch := generatedHttpRequest.original.options.Options.StopAtFirstMatch || generatedHttpRequest.original.options.StopAtFirstMatch || request.StopAtFirstMatch
@@ -1199,11 +1190,8 @@ func (request *Request) newContext(input *contextargs.Context) context.Context {
return input.Context()
}

// markUnresponsiveAddress checks if the error is an unresponsive host error and marks it
func (request *Request) markUnresponsiveAddress(input *contextargs.Context, err error) {
if err == nil {
return
}
// markHostError checks if the error is an unresponsive host error and marks it
func (request *Request) markHostError(input *contextargs.Context, err error) {
if request.options.HostErrorsCache != nil {
request.options.HostErrorsCache.MarkFailed(request.options.ProtocolType.String(), input, err)
}
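
Because the removed and added lines are interleaved in the three executor hunks above, here is the callback as it reads after the change, condensed from the diff (spmHandler, input, and the request receiver come from the surrounding code that the diff elides):

// look for unresponsive hosts and cancel inflight requests as well
spmHandler.SetOnResultCallback(func(err error) {
	// forward the result to the cache: a failure is recorded, while a nil error
	// ends up clearing the host's entry (MarkFailed now delegates to MarkFailedOrRemove)
	request.markHostError(input, err)
	if request.isUnresponsiveAddress(input) {
		// the host crossed the error threshold: stop all inflight requests
		spmHandler.Cancel()
	}
})
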