Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added support for context cancellation to engine #5096

Merged
merged 11 commits into from
Apr 25, 2024
2 changes: 1 addition & 1 deletion cmd/integration-test/library.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func executeNucleiAsLibrary(templatePath, templateURL string) ([]string, error)
}
store.Load()

_ = engine.Execute(store.Templates(), provider.NewSimpleInputProviderWithUrls(templateURL))
_ = engine.Execute(context.Background(), store.Templates(), provider.NewSimpleInputProviderWithUrls(templateURL))
engine.WorkPool().Wait() // Wait for the scan to finish

return results, nil
Expand Down
4 changes: 3 additions & 1 deletion internal/runner/lazy.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package runner

import (
"context"
"fmt"

"github.com/projectdiscovery/nuclei/v3/pkg/authprovider/authx"
Expand Down Expand Up @@ -71,7 +72,8 @@ func GetLazyAuthFetchCallback(opts *AuthLazyFetchOptions) authx.LazyFetchSecret
tmpl := tmpls[0]
// add args to tmpl here
vars := map[string]interface{}{}
ctx := scan.NewScanContext(contextargs.NewWithInput(d.Input))
mainCtx := context.Background()
ctx := scan.NewScanContext(mainCtx, contextargs.NewWithInput(mainCtx, d.Input))
for _, v := range d.Variables {
vars[v.Key] = v.Value
ctx.Input.Add(v.Key, v.Value)
Expand Down
2 changes: 1 addition & 1 deletion internal/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ func (r *Runner) executeTemplatesInput(store *loader.Store, engine *core.Engine)
if r.inputProvider == nil {
return nil, errors.New("no input provider found")
}
results := engine.ExecuteScanWithOpts(finalTemplates, r.inputProvider, r.options.DisableClustering)
results := engine.ExecuteScanWithOpts(context.Background(), finalTemplates, r.inputProvider, r.options.DisableClustering)
return results, nil
}

Expand Down
2 changes: 1 addition & 1 deletion lib/multi.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func (e *ThreadSafeNucleiEngine) ExecuteNucleiWithOpts(targets []string, opts ..
engine := core.New(tmpEngine.opts)
engine.SetExecuterOptions(unsafeOpts.executerOpts)

_ = engine.ExecuteScanWithOpts(store.Templates(), inputProvider, false)
_ = engine.ExecuteScanWithOpts(context.Background(), store.Templates(), inputProvider, false)

engine.WorkPool().Wait()
return nil
Expand Down
3 changes: 2 additions & 1 deletion lib/sdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package nuclei
import (
"bufio"
"bytes"
"context"
"io"

"github.com/projectdiscovery/nuclei/v3/pkg/authprovider"
Expand Down Expand Up @@ -210,7 +211,7 @@ func (e *NucleiEngine) ExecuteWithCallback(callback ...func(event *output.Result
}
e.resultCallbacks = append(e.resultCallbacks, filtered...)

_ = e.engine.ExecuteScanWithOpts(e.store.Templates(), e.inputProvider, false)
_ = e.engine.ExecuteScanWithOpts(context.Background(), e.store.Templates(), e.inputProvider, false)
defer e.engine.WorkPool().Wait()
return nil
}
Expand Down
37 changes: 25 additions & 12 deletions pkg/core/execute_options.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package core

import (
"context"
"sync"
"sync/atomic"

Expand All @@ -20,18 +21,18 @@ import (
//
// All the execution logic for the templates/workflows happens in this part
// of the engine.
func (e *Engine) Execute(templates []*templates.Template, target provider.InputProvider) *atomic.Bool {
return e.ExecuteScanWithOpts(templates, target, false)
func (e *Engine) Execute(ctx context.Context, templates []*templates.Template, target provider.InputProvider) *atomic.Bool {
return e.ExecuteScanWithOpts(ctx, templates, target, false)
}

// ExecuteWithResults a list of templates with results
func (e *Engine) ExecuteWithResults(templatesList []*templates.Template, target provider.InputProvider, callback func(*output.ResultEvent)) *atomic.Bool {
func (e *Engine) ExecuteWithResults(ctx context.Context, templatesList []*templates.Template, target provider.InputProvider, callback func(*output.ResultEvent)) *atomic.Bool {
e.Callback = callback
return e.ExecuteScanWithOpts(templatesList, target, false)
return e.ExecuteScanWithOpts(ctx, templatesList, target, false)
}

// ExecuteScanWithOpts executes scan with given scanStrategy
func (e *Engine) ExecuteScanWithOpts(templatesList []*templates.Template, target provider.InputProvider, noCluster bool) *atomic.Bool {
func (e *Engine) ExecuteScanWithOpts(ctx context.Context, templatesList []*templates.Template, target provider.InputProvider, noCluster bool) *atomic.Bool {
results := &atomic.Bool{}
selfcontainedWg := &sync.WaitGroup{}

Expand Down Expand Up @@ -83,14 +84,14 @@ func (e *Engine) ExecuteScanWithOpts(templatesList []*templates.Template, target
}

// Execute All SelfContained in parallel
e.executeAllSelfContained(selfContained, results, selfcontainedWg)
e.executeAllSelfContained(ctx, selfContained, results, selfcontainedWg)

strategyResult := &atomic.Bool{}
switch e.options.ScanStrategy {
case scanstrategy.TemplateSpray.String():
strategyResult = e.executeTemplateSpray(filtered, target)
strategyResult = e.executeTemplateSpray(ctx, filtered, target)
case scanstrategy.HostSpray.String():
strategyResult = e.executeHostSpray(filtered, target)
strategyResult = e.executeHostSpray(ctx, filtered, target)
}

results.CompareAndSwap(false, strategyResult.Load())
Expand All @@ -100,14 +101,20 @@ func (e *Engine) ExecuteScanWithOpts(templatesList []*templates.Template, target
}

// executeTemplateSpray executes scan using template spray strategy where targets are iterated over each template
func (e *Engine) executeTemplateSpray(templatesList []*templates.Template, target provider.InputProvider) *atomic.Bool {
func (e *Engine) executeTemplateSpray(ctx context.Context, templatesList []*templates.Template, target provider.InputProvider) *atomic.Bool {
results := &atomic.Bool{}

// wp is workpool that contains different waitgroups for
// headless and non-headless templates
wp := e.GetWorkPool()

for _, template := range templatesList {
select {
case <-ctx.Done():
return results
default:
}

// resize check point - nop if there are no changes
wp.RefreshWithConfig(e.GetWorkPoolConfig())

Expand All @@ -125,23 +132,29 @@ func (e *Engine) executeTemplateSpray(templatesList []*templates.Template, targe
// All other request types are executed here
// Note: executeTemplateWithTargets creates goroutines and blocks
// given template is executed on all targets
e.executeTemplateWithTargets(tpl, target, results)
e.executeTemplateWithTargets(ctx, tpl, target, results)
}(template)
}
wp.Wait()
return results
}

// executeHostSpray executes scan using host spray strategy where templates are iterated over each target
func (e *Engine) executeHostSpray(templatesList []*templates.Template, target provider.InputProvider) *atomic.Bool {
func (e *Engine) executeHostSpray(ctx context.Context, templatesList []*templates.Template, target provider.InputProvider) *atomic.Bool {
results := &atomic.Bool{}
wp, _ := syncutil.New(syncutil.WithSize(e.options.BulkSize + e.options.HeadlessBulkSize))

target.Iterate(func(value *contextargs.MetaInput) bool {
select {
case <-ctx.Done():
return false
default:
}

wp.Add()
go func(targetval *contextargs.MetaInput) {
defer wp.Done()
e.executeTemplatesOnTarget(templatesList, targetval, results)
e.executeTemplatesOnTarget(ctx, templatesList, targetval, results)
}(value)
return true
})
Expand Down
35 changes: 25 additions & 10 deletions pkg/core/executors.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package core

import (
"context"
"sync"
"sync/atomic"

Expand All @@ -17,14 +18,14 @@ import (
// Executors are low level executors that deals with template execution on a target

// executeAllSelfContained executes all self contained templates that do not use `target`
func (e *Engine) executeAllSelfContained(alltemplates []*templates.Template, results *atomic.Bool, sg *sync.WaitGroup) {
func (e *Engine) executeAllSelfContained(ctx context.Context, alltemplates []*templates.Template, results *atomic.Bool, sg *sync.WaitGroup) {
for _, v := range alltemplates {
sg.Add(1)
go func(template *templates.Template) {
defer sg.Done()
var err error
var match bool
ctx := scan.NewScanContext(contextargs.New())
ctx := scan.NewScanContext(ctx, contextargs.New(ctx))
if e.Callback != nil {
if results, err := template.Executer.ExecuteWithResults(ctx); err != nil {
for _, result := range results {
Expand All @@ -45,7 +46,7 @@ func (e *Engine) executeAllSelfContained(alltemplates []*templates.Template, res
}

// executeTemplateWithTarget executes a given template on x targets (with a internal targetpool(i.e concurrency))
func (e *Engine) executeTemplateWithTargets(template *templates.Template, target provider.InputProvider, results *atomic.Bool) {
func (e *Engine) executeTemplateWithTargets(ctx context.Context, template *templates.Template, target provider.InputProvider, results *atomic.Bool) {
// this is target pool i.e max target to execute
wg := e.workPool.InputPool(template.Type())

Expand Down Expand Up @@ -77,6 +78,12 @@ func (e *Engine) executeTemplateWithTargets(template *templates.Template, target
}

target.Iterate(func(scannedValue *contextargs.MetaInput) bool {
select {
case <-ctx.Done():
return false // exit
default:
}

// Best effort to track the host progression
// skips indexes lower than the minimum in-flight at interruption time
var skip bool
Expand Down Expand Up @@ -114,9 +121,9 @@ func (e *Engine) executeTemplateWithTargets(template *templates.Template, target

var match bool
var err error
ctxArgs := contextargs.New()
ctxArgs := contextargs.New(ctx)
ctxArgs.MetaInput = value
ctx := scan.NewScanContext(ctxArgs)
ctx := scan.NewScanContext(ctx, ctxArgs)
switch template.Type() {
case types.WorkflowProtocol:
match = e.executeWorkflow(ctx, template.CompiledWorkflow)
Expand Down Expand Up @@ -149,7 +156,7 @@ func (e *Engine) executeTemplateWithTargets(template *templates.Template, target
}

// executeTemplatesOnTarget execute given templates on given single target
func (e *Engine) executeTemplatesOnTarget(alltemplates []*templates.Template, target *contextargs.MetaInput, results *atomic.Bool) {
func (e *Engine) executeTemplatesOnTarget(ctx context.Context, alltemplates []*templates.Template, target *contextargs.MetaInput, results *atomic.Bool) {
// all templates are executed on single target

// wp is workpool that contains different waitgroups for
Expand All @@ -158,6 +165,12 @@ func (e *Engine) executeTemplatesOnTarget(alltemplates []*templates.Template, ta
wp := e.GetWorkPool()

for _, tpl := range alltemplates {
select {
case <-ctx.Done():
return
default:
}

// resize check point - nop if there are no changes
wp.RefreshWithConfig(e.GetWorkPoolConfig())

Expand All @@ -173,9 +186,9 @@ func (e *Engine) executeTemplatesOnTarget(alltemplates []*templates.Template, ta

var match bool
var err error
ctxArgs := contextargs.New()
ctxArgs := contextargs.New(ctx)
ctxArgs.MetaInput = value
ctx := scan.NewScanContext(ctxArgs)
ctx := scan.NewScanContext(ctx, ctxArgs)
switch template.Type() {
case types.WorkflowProtocol:
match = e.executeWorkflow(ctx, template.CompiledWorkflow)
Expand Down Expand Up @@ -230,9 +243,11 @@ func (e *ChildExecuter) Execute(template *templates.Template, value *contextargs
go func(tpl *templates.Template) {
defer wg.Done()

ctxArgs := contextargs.New()
// TODO: Workflows are a no-op for now. We need to
// implement them in the future with context cancellation
ctxArgs := contextargs.New(context.Background())
ctxArgs.MetaInput = value
ctx := scan.NewScanContext(ctxArgs)
ctx := scan.NewScanContext(context.Background(), ctxArgs)
match, err := template.Executer.Execute(ctx)
if err != nil {
gologger.Warning().Msgf("[%s] Could not execute step: %s\n", e.e.executerOpts.Colorizer.BrightBlue(template.ID), err)
Expand Down
6 changes: 3 additions & 3 deletions pkg/core/workflow_execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func (e *Engine) executeWorkflow(ctx *scan.ScanContext, w *workflows.Workflow) b

// at this point we should be at the start root execution of a workflow tree, hence we create global shared instances
workflowCookieJar, _ := cookiejar.New(nil)
ctxArgs := contextargs.New()
ctxArgs := contextargs.New(ctx.Context())
ctxArgs.MetaInput = ctx.Input.MetaInput
ctxArgs.CookieJar = workflowCookieJar

Expand Down Expand Up @@ -139,7 +139,7 @@ func (e *Engine) runWorkflowStep(template *workflows.WorkflowTemplate, ctx *scan
defer swg.Done()

// create a new context with the same input but with unset callbacks
subCtx := scan.NewScanContext(ctx.Input)
subCtx := scan.NewScanContext(ctx.Context(), ctx.Input)
if err := e.runWorkflowStep(subtemplate, subCtx, results, swg, w); err != nil {
gologger.Warning().Msgf(workflowStepExecutionError, subtemplate.Template, err)
}
Expand All @@ -165,7 +165,7 @@ func (e *Engine) runWorkflowStep(template *workflows.WorkflowTemplate, ctx *scan

go func(template *workflows.WorkflowTemplate) {
// create a new context with the same input but with unset callbacks
subCtx := scan.NewScanContext(ctx.Input)
subCtx := scan.NewScanContext(ctx.Context(), ctx.Input)
if err := e.runWorkflowStep(template, subCtx, results, swg, w); err != nil {
gologger.Warning().Msgf(workflowStepExecutionError, template.Template, err)
}
Expand Down
25 changes: 13 additions & 12 deletions pkg/core/workflow_execute_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package core

import (
"context"
"testing"

"github.com/projectdiscovery/nuclei/v3/pkg/model/types/stringslice"
Expand All @@ -25,8 +26,8 @@ func TestWorkflowsSimple(t *testing.T) {
}}

engine := &Engine{}
input := contextargs.NewWithInput("https://test.com")
ctx := scan.NewScanContext(input)
input := contextargs.NewWithInput(context.Background(), "https://test.com")
ctx := scan.NewScanContext(context.Background(), input)
matched := engine.executeWorkflow(ctx, workflow)
require.True(t, matched, "could not get correct match value")
}
Expand All @@ -49,8 +50,8 @@ func TestWorkflowsSimpleMultiple(t *testing.T) {
}}

engine := &Engine{}
input := contextargs.NewWithInput("https://test.com")
ctx := scan.NewScanContext(input)
input := contextargs.NewWithInput(context.Background(), "https://test.com")
ctx := scan.NewScanContext(context.Background(), input)
matched := engine.executeWorkflow(ctx, workflow)
require.True(t, matched, "could not get correct match value")

Expand All @@ -77,8 +78,8 @@ func TestWorkflowsSubtemplates(t *testing.T) {
}}

engine := &Engine{}
input := contextargs.NewWithInput("https://test.com")
ctx := scan.NewScanContext(input)
input := contextargs.NewWithInput(context.Background(), "https://test.com")
ctx := scan.NewScanContext(context.Background(), input)
matched := engine.executeWorkflow(ctx, workflow)
require.True(t, matched, "could not get correct match value")

Expand All @@ -103,8 +104,8 @@ func TestWorkflowsSubtemplatesNoMatch(t *testing.T) {
}}

engine := &Engine{}
input := contextargs.NewWithInput("https://test.com")
ctx := scan.NewScanContext(input)
input := contextargs.NewWithInput(context.Background(), "https://test.com")
ctx := scan.NewScanContext(context.Background(), input)
matched := engine.executeWorkflow(ctx, workflow)
require.False(t, matched, "could not get correct match value")

Expand Down Expand Up @@ -134,8 +135,8 @@ func TestWorkflowsSubtemplatesWithMatcher(t *testing.T) {
}}

engine := &Engine{}
input := contextargs.NewWithInput("https://test.com")
ctx := scan.NewScanContext(input)
input := contextargs.NewWithInput(context.Background(), "https://test.com")
ctx := scan.NewScanContext(context.Background(), input)
matched := engine.executeWorkflow(ctx, workflow)
require.True(t, matched, "could not get correct match value")

Expand Down Expand Up @@ -165,8 +166,8 @@ func TestWorkflowsSubtemplatesWithMatcherNoMatch(t *testing.T) {
}}

engine := &Engine{}
input := contextargs.NewWithInput("https://test.com")
ctx := scan.NewScanContext(input)
input := contextargs.NewWithInput(context.Background(), "https://test.com")
ctx := scan.NewScanContext(context.Background(), input)
matched := engine.executeWorkflow(ctx, workflow)
require.False(t, matched, "could not get correct match value")

Expand Down
4 changes: 4 additions & 0 deletions pkg/input/provider/list/hmap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package list
import (
"net"
"os"
"runtime"
"strconv"
"strings"
"testing"
Expand Down Expand Up @@ -77,6 +78,9 @@ func (m *mockDnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}

func Test_scanallips_normalizeStoreInputValue(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Skipping test see: https://github.com/projectdiscovery/nuclei/issues/5097")
}
srv := &dns.Server{Addr: ":" + strconv.Itoa(61234), Net: "udp"}
srv.Handler = &mockDnsHandler{}

Expand Down
Loading
Loading