Skip to content

Commit

Permalink
feat: added support for context cancellation to engine (#5096)
Browse files Browse the repository at this point in the history
* feat: added support for context cancellation to engine

* misc

* feat: added contexts everywhere

* misc

* misc

* use granular http timeouts and increase http timeout to 30s using multiplier

* track response header timeout in mhe

* update responseHeaderTimeout to 5sec

* skip failing windows test

---------

Co-authored-by: Tarun Koyalwar <[email protected]>
  • Loading branch information
Ice3man543 and tarunKoyalwar authored Apr 25, 2024
1 parent 3dfcec0 commit 0b82e8b
Show file tree
Hide file tree
Showing 40 changed files with 279 additions and 113 deletions.
2 changes: 1 addition & 1 deletion cmd/integration-test/library.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func executeNucleiAsLibrary(templatePath, templateURL string) ([]string, error)
}
store.Load()

_ = engine.Execute(store.Templates(), provider.NewSimpleInputProviderWithUrls(templateURL))
_ = engine.Execute(context.Background(), store.Templates(), provider.NewSimpleInputProviderWithUrls(templateURL))
engine.WorkPool().Wait() // Wait for the scan to finish

return results, nil
Expand Down
4 changes: 3 additions & 1 deletion internal/runner/lazy.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package runner

import (
"context"
"fmt"

"github.com/projectdiscovery/nuclei/v3/pkg/authprovider/authx"
Expand Down Expand Up @@ -71,7 +72,8 @@ func GetLazyAuthFetchCallback(opts *AuthLazyFetchOptions) authx.LazyFetchSecret
tmpl := tmpls[0]
// add args to tmpl here
vars := map[string]interface{}{}
ctx := scan.NewScanContext(contextargs.NewWithInput(d.Input))
mainCtx := context.Background()
ctx := scan.NewScanContext(mainCtx, contextargs.NewWithInput(mainCtx, d.Input))
for _, v := range d.Variables {
vars[v.Key] = v.Value
ctx.Input.Add(v.Key, v.Value)
Expand Down
2 changes: 1 addition & 1 deletion internal/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ func (r *Runner) executeTemplatesInput(store *loader.Store, engine *core.Engine)
if r.inputProvider == nil {
return nil, errors.New("no input provider found")
}
results := engine.ExecuteScanWithOpts(finalTemplates, r.inputProvider, r.options.DisableClustering)
results := engine.ExecuteScanWithOpts(context.Background(), finalTemplates, r.inputProvider, r.options.DisableClustering)
return results, nil
}

Expand Down
2 changes: 1 addition & 1 deletion lib/multi.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func (e *ThreadSafeNucleiEngine) ExecuteNucleiWithOpts(targets []string, opts ..
engine := core.New(tmpEngine.opts)
engine.SetExecuterOptions(unsafeOpts.executerOpts)

_ = engine.ExecuteScanWithOpts(store.Templates(), inputProvider, false)
_ = engine.ExecuteScanWithOpts(context.Background(), store.Templates(), inputProvider, false)

engine.WorkPool().Wait()
return nil
Expand Down
3 changes: 2 additions & 1 deletion lib/sdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package nuclei
import (
"bufio"
"bytes"
"context"
"io"

"github.com/projectdiscovery/nuclei/v3/pkg/authprovider"
Expand Down Expand Up @@ -210,7 +211,7 @@ func (e *NucleiEngine) ExecuteWithCallback(callback ...func(event *output.Result
}
e.resultCallbacks = append(e.resultCallbacks, filtered...)

_ = e.engine.ExecuteScanWithOpts(e.store.Templates(), e.inputProvider, false)
_ = e.engine.ExecuteScanWithOpts(context.Background(), e.store.Templates(), e.inputProvider, false)
defer e.engine.WorkPool().Wait()
return nil
}
Expand Down
37 changes: 25 additions & 12 deletions pkg/core/execute_options.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package core

import (
"context"
"sync"
"sync/atomic"

Expand All @@ -20,18 +21,18 @@ import (
//
// All the execution logic for the templates/workflows happens in this part
// of the engine.
func (e *Engine) Execute(templates []*templates.Template, target provider.InputProvider) *atomic.Bool {
return e.ExecuteScanWithOpts(templates, target, false)
func (e *Engine) Execute(ctx context.Context, templates []*templates.Template, target provider.InputProvider) *atomic.Bool {
return e.ExecuteScanWithOpts(ctx, templates, target, false)
}

// ExecuteWithResults a list of templates with results
func (e *Engine) ExecuteWithResults(templatesList []*templates.Template, target provider.InputProvider, callback func(*output.ResultEvent)) *atomic.Bool {
func (e *Engine) ExecuteWithResults(ctx context.Context, templatesList []*templates.Template, target provider.InputProvider, callback func(*output.ResultEvent)) *atomic.Bool {
e.Callback = callback
return e.ExecuteScanWithOpts(templatesList, target, false)
return e.ExecuteScanWithOpts(ctx, templatesList, target, false)
}

// ExecuteScanWithOpts executes scan with given scanStrategy
func (e *Engine) ExecuteScanWithOpts(templatesList []*templates.Template, target provider.InputProvider, noCluster bool) *atomic.Bool {
func (e *Engine) ExecuteScanWithOpts(ctx context.Context, templatesList []*templates.Template, target provider.InputProvider, noCluster bool) *atomic.Bool {
results := &atomic.Bool{}
selfcontainedWg := &sync.WaitGroup{}

Expand Down Expand Up @@ -83,14 +84,14 @@ func (e *Engine) ExecuteScanWithOpts(templatesList []*templates.Template, target
}

// Execute All SelfContained in parallel
e.executeAllSelfContained(selfContained, results, selfcontainedWg)
e.executeAllSelfContained(ctx, selfContained, results, selfcontainedWg)

strategyResult := &atomic.Bool{}
switch e.options.ScanStrategy {
case scanstrategy.TemplateSpray.String():
strategyResult = e.executeTemplateSpray(filtered, target)
strategyResult = e.executeTemplateSpray(ctx, filtered, target)
case scanstrategy.HostSpray.String():
strategyResult = e.executeHostSpray(filtered, target)
strategyResult = e.executeHostSpray(ctx, filtered, target)
}

results.CompareAndSwap(false, strategyResult.Load())
Expand All @@ -100,14 +101,20 @@ func (e *Engine) ExecuteScanWithOpts(templatesList []*templates.Template, target
}

// executeTemplateSpray executes scan using template spray strategy where targets are iterated over each template
func (e *Engine) executeTemplateSpray(templatesList []*templates.Template, target provider.InputProvider) *atomic.Bool {
func (e *Engine) executeTemplateSpray(ctx context.Context, templatesList []*templates.Template, target provider.InputProvider) *atomic.Bool {
results := &atomic.Bool{}

// wp is workpool that contains different waitgroups for
// headless and non-headless templates
wp := e.GetWorkPool()

for _, template := range templatesList {
select {
case <-ctx.Done():
return results
default:
}

// resize check point - nop if there are no changes
wp.RefreshWithConfig(e.GetWorkPoolConfig())

Expand All @@ -125,23 +132,29 @@ func (e *Engine) executeTemplateSpray(templatesList []*templates.Template, targe
// All other request types are executed here
// Note: executeTemplateWithTargets creates goroutines and blocks
// given template is executed on all targets
e.executeTemplateWithTargets(tpl, target, results)
e.executeTemplateWithTargets(ctx, tpl, target, results)
}(template)
}
wp.Wait()
return results
}

// executeHostSpray executes scan using host spray strategy where templates are iterated over each target
func (e *Engine) executeHostSpray(templatesList []*templates.Template, target provider.InputProvider) *atomic.Bool {
func (e *Engine) executeHostSpray(ctx context.Context, templatesList []*templates.Template, target provider.InputProvider) *atomic.Bool {
results := &atomic.Bool{}
wp, _ := syncutil.New(syncutil.WithSize(e.options.BulkSize + e.options.HeadlessBulkSize))

target.Iterate(func(value *contextargs.MetaInput) bool {
select {
case <-ctx.Done():
return false
default:
}

wp.Add()
go func(targetval *contextargs.MetaInput) {
defer wp.Done()
e.executeTemplatesOnTarget(templatesList, targetval, results)
e.executeTemplatesOnTarget(ctx, templatesList, targetval, results)
}(value)
return true
})
Expand Down
35 changes: 25 additions & 10 deletions pkg/core/executors.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package core

import (
"context"
"sync"
"sync/atomic"

Expand All @@ -17,14 +18,14 @@ import (
// Executors are low level executors that deals with template execution on a target

// executeAllSelfContained executes all self contained templates that do not use `target`
func (e *Engine) executeAllSelfContained(alltemplates []*templates.Template, results *atomic.Bool, sg *sync.WaitGroup) {
func (e *Engine) executeAllSelfContained(ctx context.Context, alltemplates []*templates.Template, results *atomic.Bool, sg *sync.WaitGroup) {
for _, v := range alltemplates {
sg.Add(1)
go func(template *templates.Template) {
defer sg.Done()
var err error
var match bool
ctx := scan.NewScanContext(contextargs.New())
ctx := scan.NewScanContext(ctx, contextargs.New(ctx))
if e.Callback != nil {
if results, err := template.Executer.ExecuteWithResults(ctx); err != nil {
for _, result := range results {
Expand All @@ -45,7 +46,7 @@ func (e *Engine) executeAllSelfContained(alltemplates []*templates.Template, res
}

// executeTemplateWithTarget executes a given template on x targets (with a internal targetpool(i.e concurrency))
func (e *Engine) executeTemplateWithTargets(template *templates.Template, target provider.InputProvider, results *atomic.Bool) {
func (e *Engine) executeTemplateWithTargets(ctx context.Context, template *templates.Template, target provider.InputProvider, results *atomic.Bool) {
// this is target pool i.e max target to execute
wg := e.workPool.InputPool(template.Type())

Expand Down Expand Up @@ -77,6 +78,12 @@ func (e *Engine) executeTemplateWithTargets(template *templates.Template, target
}

target.Iterate(func(scannedValue *contextargs.MetaInput) bool {
select {
case <-ctx.Done():
return false // exit
default:
}

// Best effort to track the host progression
// skips indexes lower than the minimum in-flight at interruption time
var skip bool
Expand Down Expand Up @@ -114,9 +121,9 @@ func (e *Engine) executeTemplateWithTargets(template *templates.Template, target

var match bool
var err error
ctxArgs := contextargs.New()
ctxArgs := contextargs.New(ctx)
ctxArgs.MetaInput = value
ctx := scan.NewScanContext(ctxArgs)
ctx := scan.NewScanContext(ctx, ctxArgs)
switch template.Type() {
case types.WorkflowProtocol:
match = e.executeWorkflow(ctx, template.CompiledWorkflow)
Expand Down Expand Up @@ -149,7 +156,7 @@ func (e *Engine) executeTemplateWithTargets(template *templates.Template, target
}

// executeTemplatesOnTarget execute given templates on given single target
func (e *Engine) executeTemplatesOnTarget(alltemplates []*templates.Template, target *contextargs.MetaInput, results *atomic.Bool) {
func (e *Engine) executeTemplatesOnTarget(ctx context.Context, alltemplates []*templates.Template, target *contextargs.MetaInput, results *atomic.Bool) {
// all templates are executed on single target

// wp is workpool that contains different waitgroups for
Expand All @@ -158,6 +165,12 @@ func (e *Engine) executeTemplatesOnTarget(alltemplates []*templates.Template, ta
wp := e.GetWorkPool()

for _, tpl := range alltemplates {
select {
case <-ctx.Done():
return
default:
}

// resize check point - nop if there are no changes
wp.RefreshWithConfig(e.GetWorkPoolConfig())

Expand All @@ -173,9 +186,9 @@ func (e *Engine) executeTemplatesOnTarget(alltemplates []*templates.Template, ta

var match bool
var err error
ctxArgs := contextargs.New()
ctxArgs := contextargs.New(ctx)
ctxArgs.MetaInput = value
ctx := scan.NewScanContext(ctxArgs)
ctx := scan.NewScanContext(ctx, ctxArgs)
switch template.Type() {
case types.WorkflowProtocol:
match = e.executeWorkflow(ctx, template.CompiledWorkflow)
Expand Down Expand Up @@ -230,9 +243,11 @@ func (e *ChildExecuter) Execute(template *templates.Template, value *contextargs
go func(tpl *templates.Template) {
defer wg.Done()

ctxArgs := contextargs.New()
// TODO: Workflows are a no-op for now. We need to
// implement them in the future with context cancellation
ctxArgs := contextargs.New(context.Background())
ctxArgs.MetaInput = value
ctx := scan.NewScanContext(ctxArgs)
ctx := scan.NewScanContext(context.Background(), ctxArgs)
match, err := template.Executer.Execute(ctx)
if err != nil {
gologger.Warning().Msgf("[%s] Could not execute step: %s\n", e.e.executerOpts.Colorizer.BrightBlue(template.ID), err)
Expand Down
6 changes: 3 additions & 3 deletions pkg/core/workflow_execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func (e *Engine) executeWorkflow(ctx *scan.ScanContext, w *workflows.Workflow) b

// at this point we should be at the start root execution of a workflow tree, hence we create global shared instances
workflowCookieJar, _ := cookiejar.New(nil)
ctxArgs := contextargs.New()
ctxArgs := contextargs.New(ctx.Context())
ctxArgs.MetaInput = ctx.Input.MetaInput
ctxArgs.CookieJar = workflowCookieJar

Expand Down Expand Up @@ -139,7 +139,7 @@ func (e *Engine) runWorkflowStep(template *workflows.WorkflowTemplate, ctx *scan
defer swg.Done()

// create a new context with the same input but with unset callbacks
subCtx := scan.NewScanContext(ctx.Input)
subCtx := scan.NewScanContext(ctx.Context(), ctx.Input)
if err := e.runWorkflowStep(subtemplate, subCtx, results, swg, w); err != nil {
gologger.Warning().Msgf(workflowStepExecutionError, subtemplate.Template, err)
}
Expand All @@ -165,7 +165,7 @@ func (e *Engine) runWorkflowStep(template *workflows.WorkflowTemplate, ctx *scan

go func(template *workflows.WorkflowTemplate) {
// create a new context with the same input but with unset callbacks
subCtx := scan.NewScanContext(ctx.Input)
subCtx := scan.NewScanContext(ctx.Context(), ctx.Input)
if err := e.runWorkflowStep(template, subCtx, results, swg, w); err != nil {
gologger.Warning().Msgf(workflowStepExecutionError, template.Template, err)
}
Expand Down
25 changes: 13 additions & 12 deletions pkg/core/workflow_execute_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package core

import (
"context"
"testing"

"github.com/projectdiscovery/nuclei/v3/pkg/model/types/stringslice"
Expand All @@ -25,8 +26,8 @@ func TestWorkflowsSimple(t *testing.T) {
}}

engine := &Engine{}
input := contextargs.NewWithInput("https://test.com")
ctx := scan.NewScanContext(input)
input := contextargs.NewWithInput(context.Background(), "https://test.com")
ctx := scan.NewScanContext(context.Background(), input)
matched := engine.executeWorkflow(ctx, workflow)
require.True(t, matched, "could not get correct match value")
}
Expand All @@ -49,8 +50,8 @@ func TestWorkflowsSimpleMultiple(t *testing.T) {
}}

engine := &Engine{}
input := contextargs.NewWithInput("https://test.com")
ctx := scan.NewScanContext(input)
input := contextargs.NewWithInput(context.Background(), "https://test.com")
ctx := scan.NewScanContext(context.Background(), input)
matched := engine.executeWorkflow(ctx, workflow)
require.True(t, matched, "could not get correct match value")

Expand All @@ -77,8 +78,8 @@ func TestWorkflowsSubtemplates(t *testing.T) {
}}

engine := &Engine{}
input := contextargs.NewWithInput("https://test.com")
ctx := scan.NewScanContext(input)
input := contextargs.NewWithInput(context.Background(), "https://test.com")
ctx := scan.NewScanContext(context.Background(), input)
matched := engine.executeWorkflow(ctx, workflow)
require.True(t, matched, "could not get correct match value")

Expand All @@ -103,8 +104,8 @@ func TestWorkflowsSubtemplatesNoMatch(t *testing.T) {
}}

engine := &Engine{}
input := contextargs.NewWithInput("https://test.com")
ctx := scan.NewScanContext(input)
input := contextargs.NewWithInput(context.Background(), "https://test.com")
ctx := scan.NewScanContext(context.Background(), input)
matched := engine.executeWorkflow(ctx, workflow)
require.False(t, matched, "could not get correct match value")

Expand Down Expand Up @@ -134,8 +135,8 @@ func TestWorkflowsSubtemplatesWithMatcher(t *testing.T) {
}}

engine := &Engine{}
input := contextargs.NewWithInput("https://test.com")
ctx := scan.NewScanContext(input)
input := contextargs.NewWithInput(context.Background(), "https://test.com")
ctx := scan.NewScanContext(context.Background(), input)
matched := engine.executeWorkflow(ctx, workflow)
require.True(t, matched, "could not get correct match value")

Expand Down Expand Up @@ -165,8 +166,8 @@ func TestWorkflowsSubtemplatesWithMatcherNoMatch(t *testing.T) {
}}

engine := &Engine{}
input := contextargs.NewWithInput("https://test.com")
ctx := scan.NewScanContext(input)
input := contextargs.NewWithInput(context.Background(), "https://test.com")
ctx := scan.NewScanContext(context.Background(), input)
matched := engine.executeWorkflow(ctx, workflow)
require.False(t, matched, "could not get correct match value")

Expand Down
4 changes: 4 additions & 0 deletions pkg/input/provider/list/hmap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package list
import (
"net"
"os"
"runtime"
"strconv"
"strings"
"testing"
Expand Down Expand Up @@ -77,6 +78,9 @@ func (m *mockDnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}

func Test_scanallips_normalizeStoreInputValue(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Skipping test see: https://github.com/projectdiscovery/nuclei/issues/5097")
}
srv := &dns.Server{Addr: ":" + strconv.Itoa(61234), Net: "udp"}
srv.Handler = &mockDnsHandler{}

Expand Down
Loading

0 comments on commit 0b82e8b

Please sign in to comment.