diff --git a/pkg/protocols/dns/cluster.go b/pkg/protocols/dns/cluster.go index 303a287be3..fd501e7fa2 100644 --- a/pkg/protocols/dns/cluster.go +++ b/pkg/protocols/dns/cluster.go @@ -1,17 +1,17 @@ package dns +import ( + "fmt" + + "github.com/cespare/xxhash" +) + // CanCluster returns true if the request can be clustered. // // This used by the clustering engine to decide whether two requests // are similar enough to be considered one and can be checked by // just adding the matcher/extractors for the request and the correct IDs. func (request *Request) CanCluster(other *Request) bool { - if request.Name != other.Name || - request.class != other.class || - request.Retries != other.Retries || - request.question != other.question { - return false - } if request.Recursion != nil { if other.Recursion == nil { return false @@ -23,6 +23,11 @@ func (request *Request) CanCluster(other *Request) bool { return true } +func (request *Request) ClusterHash() uint64 { + inp := fmt.Sprintf("%s-%d-%d-%d", request.Name, request.class, request.Retries, request.question) + return xxhash.Sum64String(inp) +} + func (request *Request) IsClusterable() bool { return !(len(request.Resolvers) > 0 || request.Trace || request.ID != "") } diff --git a/pkg/protocols/http/cluster.go b/pkg/protocols/http/cluster.go index 78bd7c6cb5..bdc91d716e 100644 --- a/pkg/protocols/http/cluster.go +++ b/pkg/protocols/http/cluster.go @@ -1,7 +1,10 @@ package http import ( - sliceutil "github.com/projectdiscovery/utils/slice" + "fmt" + "strings" + + "github.com/cespare/xxhash" "golang.org/x/exp/maps" ) @@ -11,21 +14,13 @@ import ( // are similar enough to be considered one and can be checked by // just adding the matcher/extractors for the request and the correct IDs. func (request *Request) CanCluster(other *Request) bool { - if request.Method != other.Method || - request.MaxRedirects != other.MaxRedirects || - request.DisableCookie != other.DisableCookie || - request.Redirects != other.Redirects { - return false - } - if !sliceutil.Equal(request.Path, other.Path) { - return false - } - if !maps.Equal(request.Headers, other.Headers) { - return false - } - return true + return maps.Equal(request.Headers, other.Headers) } +func (request *Request) ClusterHash() uint64 { + inp := fmt.Sprintf("%s-%d-%t-%t-%s", request.Method.String(), request.MaxRedirects, request.DisableCookie, request.Redirects, strings.Join(request.Path, "-")) + return xxhash.Sum64String(inp) +} func (request *Request) IsClusterable() bool { return !(len(request.Payloads) > 0 || len(request.Fuzzing) > 0 || len(request.Raw) > 0 || len(request.Body) > 0 || request.Unsafe || request.NeedsRequestCondition() || request.Name != "") diff --git a/pkg/protocols/ssl/ssl.go b/pkg/protocols/ssl/ssl.go index dad9635ab5..a8ceb4cfb6 100644 --- a/pkg/protocols/ssl/ssl.go +++ b/pkg/protocols/ssl/ssl.go @@ -5,6 +5,7 @@ import ( "net" "time" + "github.com/cespare/xxhash" "github.com/fatih/structs" jsoniter "github.com/json-iterator/go" "github.com/pkg/errors" @@ -102,12 +103,14 @@ type Request struct { // CanCluster returns true if the request can be clustered. func (request *Request) CanCluster(other *Request) bool { - if request.Address != other.Address || request.ScanMode != other.ScanMode { - return false - } return true } +func (request *Request) ClusterHash() uint64 { + inp := fmt.Sprintf("%s-%s", request.Address, request.ScanMode) + return xxhash.Sum64String(inp) +} + func (request *Request) IsClusterable() bool { return !(len(request.CipherSuites) > 0 || request.MinVersion != "" || request.MaxVersion != "") } diff --git a/pkg/templates/cluster.go b/pkg/templates/cluster.go index dcbf134de3..d84626d3bb 100644 --- a/pkg/templates/cluster.go +++ b/pkg/templates/cluster.go @@ -41,9 +41,9 @@ import ( // Finally, the engine creates a single executer with a clusteredexecuter for all templates // in a cluster. func Cluster(list []*Template) [][]*Template { - http := make(map[int]*Template) - dns := make(map[int]*Template) - ssl := make(map[int]*Template) + http := make(map[uint64]map[int]*Template) + dns := make(map[uint64]map[int]*Template) + ssl := make(map[uint64]map[int]*Template) final := [][]*Template{} @@ -58,19 +58,31 @@ func Cluster(list []*Template) [][]*Template { switch { case len(template.RequestsDNS) == 1: if template.RequestsDNS[0].IsClusterable() { - dns[key] = template + hash := template.RequestsDNS[0].ClusterHash() + if dns[hash] == nil { + dns[hash] = map[int]*Template{} + } + dns[hash][key] = template } else { final = append(final, []*Template{template}) } case len(template.RequestsHTTP) == 1: if template.RequestsHTTP[0].IsClusterable() { - http[key] = template + hash := template.RequestsHTTP[0].ClusterHash() + if http[hash] == nil { + http[hash] = map[int]*Template{} + } + http[hash][key] = template } else { final = append(final, []*Template{template}) } case len(template.RequestsSSL) == 1: if template.RequestsSSL[0].IsClusterable() { - ssl[key] = template + hash := template.RequestsSSL[0].ClusterHash() + if ssl[hash] == nil { + ssl[hash] = map[int]*Template{} + } + ssl[hash][key] = template } else { final = append(final, []*Template{template}) } @@ -79,43 +91,48 @@ func Cluster(list []*Template) [][]*Template { } } - // Cluster together dns, http and ssl individually - - for key, template := range dns { - cluster := []*Template{template} - delete(dns, key) - for otherKey, other := range dns { - if template.RequestsDNS[0].CanCluster(other.RequestsDNS[0]) { - delete(dns, otherKey) - cluster = append(cluster, other) + for _, templates := range dns { + for key, template := range templates { + cluster := []*Template{template} + delete(templates, key) + for otherKey, other := range templates { + if template.RequestsDNS[0].CanCluster(other.RequestsDNS[0]) { + cluster = append(cluster, other) + delete(templates, otherKey) + } } + final = append(final, cluster) } - final = append(final, cluster) } - for key, template := range http { - cluster := []*Template{template} - delete(http, key) - for otherKey, other := range http { - if template.RequestsHTTP[0].CanCluster(other.RequestsHTTP[0]) { - delete(http, otherKey) - cluster = append(cluster, other) + for _, templates := range http { + for key, template := range templates { + cluster := []*Template{template} + delete(templates, key) + for otherKey, other := range templates { + if template.RequestsHTTP[0].CanCluster(other.RequestsHTTP[0]) { + cluster = append(cluster, other) + delete(templates, otherKey) + } } + final = append(final, cluster) } - final = append(final, cluster) } - for key, template := range ssl { - cluster := []*Template{template} - delete(ssl, key) - for otherKey, other := range ssl { - if template.RequestsSSL[0].CanCluster(other.RequestsSSL[0]) { - delete(ssl, otherKey) - cluster = append(cluster, other) + for _, templates := range ssl { + for key, template := range templates { + delete(templates, key) + cluster := []*Template{template} + for otherKey, other := range templates { + if template.RequestsSSL[0].CanCluster(other.RequestsSSL[0]) { + cluster = append(cluster, other) + delete(templates, otherKey) + } } + final = append(final, cluster) } - final = append(final, cluster) } + return final }