From 2bbe81a620d7ab617c7dbfa57a03196a674e745d Mon Sep 17 00:00:00 2001 From: Takeshi Yoneda Date: Wed, 13 Mar 2024 15:06:48 +0900 Subject: [PATCH] wazevo(frontend): allocation free initializeCurrentBlockKnownBounds (#2154) Signed-off-by: Takeshi Yoneda --- internal/engine/wazevo/frontend/frontend.go | 78 ++++++++++++------- internal/engine/wazevo/frontend/sort_id.go | 15 ++++ .../engine/wazevo/frontend/sort_id_old.go | 17 ++++ 3 files changed, 81 insertions(+), 29 deletions(-) create mode 100644 internal/engine/wazevo/frontend/sort_id.go create mode 100644 internal/engine/wazevo/frontend/sort_id_old.go diff --git a/internal/engine/wazevo/frontend/frontend.go b/internal/engine/wazevo/frontend/frontend.go index d97140658b..8e5a5c3faa 100644 --- a/internal/engine/wazevo/frontend/frontend.go +++ b/internal/engine/wazevo/frontend/frontend.go @@ -3,6 +3,7 @@ package frontend import ( "bytes" + "math" "github.com/tetratelabs/wazero/internal/engine/wazevo/ssa" "github.com/tetratelabs/wazero/internal/engine/wazevo/wazevoapi" @@ -60,6 +61,11 @@ type Compiler struct { varLengthKnownSafeBoundWithIDPool wazevoapi.VarLengthPool[knownSafeBoundWithID] execCtxPtrValue, moduleCtxPtrValue ssa.Value + + // Following are reused for the known safe bounds analysis. + + pointers []int + bounds [][]knownSafeBoundWithID } type ( @@ -498,6 +504,8 @@ func (c *Compiler) finalizeKnownSafeBoundsAtTheEndOfBlock(bID ssa.BasicBlockID) p := &c.varLengthKnownSafeBoundWithIDPool size := len(c.knownSafeBoundsSet) allocated := c.varLengthKnownSafeBoundWithIDPool.Allocate(size) + // Sort the known safe bounds by the value ID so that we can use the intersection algorithm in initializeCurrentBlockKnownBounds. + sortSSAValueIDs(c.knownSafeBoundsSet) for _, vID := range c.knownSafeBoundsSet { kb := c.knownSafeBounds[vID] allocated = allocated.Append(p, knownSafeBoundWithID{ @@ -519,42 +527,54 @@ func (c *Compiler) initializeCurrentBlockKnownBounds() { c.recordKnownSafeBound(kb.id, kb.bound, kb.absoluteAddr) } default: - primary := currentBlk.Pred(0).ID() - type mapVal struct { - kb knownSafeBoundWithID - count int - } - set := map[ssa.ValueID]mapVal{} - for _, kb := range c.getKnownSafeBoundsAtTheEndOfBlocks(primary).View() { - if kb.valid() { - set[kb.id] = mapVal{kb, 1} - } + c.pointers = c.pointers[:0] + c.bounds = c.bounds[:0] + for i := 0; i < preds; i++ { + c.bounds = append(c.bounds, c.getKnownSafeBoundsAtTheEndOfBlocks(currentBlk.Pred(i).ID()).View()) + c.pointers = append(c.pointers, 0) } - // If there are more than one predecessor, we need to find the intersection of the known safe bounds. - for i := 1; i < preds; i++ { - pred := currentBlk.Pred(i).ID() - for _, kb := range c.getKnownSafeBoundsAtTheEndOfBlocks(pred).View() { - if !kb.valid() { - continue + // If there are multiple predecessors, we need to find the intersection of the known safe bounds. + + outer: + for { + smallestID := ssa.ValueID(math.MaxUint32) + for i, ptr := range c.pointers { + if ptr >= len(c.bounds[i]) { + break outer } - mv, ok := set[kb.id] - if !ok { - continue + cb := &c.bounds[i][ptr] + if id := cb.id; id < smallestID { + smallestID = cb.id } - mv.count++ - // Choose the lower bound. - if kb.bound < mv.kb.bound { - mv.kb = kb + } + + // Check if current elements are the same across all lists. + same := true + minBound := uint64(math.MaxUint64) + for i := 0; i < preds; i++ { + cb := &c.bounds[i][c.pointers[i]] + if cb.id != smallestID { + same = false + break + } else { + if cb.bound < minBound { + minBound = cb.bound + } } - set[kb.id] = mv } - } - for _, mv := range set { - if mv.count == preds { - kb := mv.kb + + if same { // All elements are the same. // Absolute address cannot be used in the intersection since the value might be only defined in one of the predecessors. - c.recordKnownSafeBound(kb.id, kb.bound, ssa.ValueInvalid) + c.recordKnownSafeBound(smallestID, minBound, ssa.ValueInvalid) + } + + // Move pointer(s) for the smallest ID forward (if same, move all). + for i := 0; i < preds; i++ { + cb := &c.bounds[i][c.pointers[i]] + if cb.id == smallestID { + c.pointers[i]++ + } } } } diff --git a/internal/engine/wazevo/frontend/sort_id.go b/internal/engine/wazevo/frontend/sort_id.go new file mode 100644 index 0000000000..1296706f5c --- /dev/null +++ b/internal/engine/wazevo/frontend/sort_id.go @@ -0,0 +1,15 @@ +//go:build go1.21 + +package frontend + +import ( + "slices" + + "github.com/tetratelabs/wazero/internal/engine/wazevo/ssa" +) + +func sortSSAValueIDs(IDs []ssa.ValueID) { + slices.SortFunc(IDs, func(i, j ssa.ValueID) int { + return int(i) - int(j) + }) +} diff --git a/internal/engine/wazevo/frontend/sort_id_old.go b/internal/engine/wazevo/frontend/sort_id_old.go new file mode 100644 index 0000000000..2e786a160d --- /dev/null +++ b/internal/engine/wazevo/frontend/sort_id_old.go @@ -0,0 +1,17 @@ +//go:build !go1.21 + +// TODO: delete after the floor Go version is 1.21 + +package frontend + +import ( + "sort" + + "github.com/tetratelabs/wazero/internal/engine/wazevo/ssa" +) + +func sortSSAValueIDs(IDs []ssa.ValueID) { + sort.SliceStable(IDs, func(i, j int) bool { + return int(IDs[i]) < int(IDs[j]) + }) +}