diff --git a/.github/dependabot.yaml b/.github/dependabot.yaml new file mode 100644 index 0000000..603f653 --- /dev/null +++ b/.github/dependabot.yaml @@ -0,0 +1,11 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + + - package-ecosystem: "gomod" + directory: "/" + schedule: + interval: "weekly" diff --git a/.github/workflows/codeql-analysis.yaml b/.github/workflows/codeql-analysis.yaml new file mode 100644 index 0000000..1a995d2 --- /dev/null +++ b/.github/workflows/codeql-analysis.yaml @@ -0,0 +1,70 @@ +# For most projects, this workflow file will not need changing; you simply need +# to commit it to your repository. +# +# You may wish to alter this file to override the set of languages analyzed, +# or to provide custom queries or build logic. +# +# ******** NOTE ******** +# We have attempted to detect the languages in your repository. Please check +# the `language` matrix defined below to confirm you have the correct set of +# supported CodeQL languages. +# +name: "CodeQL" + +on: + push: + branches: [ main ] + pull_request: + # The branches below must be a subset of the branches above + branches: [ main ] + schedule: + - cron: '41 14 * * 2' + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + security-events: write + + strategy: + fail-fast: false + matrix: + language: [ 'go' ] + # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] + # Learn more about CodeQL language support at https://git.io/codeql-language-support + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + # Initializes the CodeQL tools for scanning. + - name: Initialize CodeQL + uses: github/codeql-action/init@v3 + with: + languages: ${{ matrix.language }} + # If you wish to specify custom queries, you can do so here or in a config file. + # By default, queries listed here will override any specified in a config file. + # Prefix the list here with "+" to use these queries and those in the config file. + # queries: ./path/to/local/query, your-org/your-repo/queries@main + + # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). + # If this step fails, then you should remove it and run the build manually (see below) + - name: Autobuild + uses: github/codeql-action/autobuild@v3 + + # ℹ️ Command-line programs to run using the OS shell. + # 📚 https://git.io/JvXDl + + # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines + # and modify them (or add more) to build your code if your project + # uses a compiled language + + #- run: | + # make bootstrap + # make release + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v3 diff --git a/.github/workflows/coverage.yaml b/.github/workflows/coverage.yaml new file mode 100644 index 0000000..51dced3 --- /dev/null +++ b/.github/workflows/coverage.yaml @@ -0,0 +1,14 @@ +name: coverage +on: push + +jobs: + coverage: + runs-on: ubuntu-latest + name: Go test coverage + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: "stable" + - run: go test -coverprofile=coverage.txt -covermode=atomic + - uses: codecov/codecov-action@v4 diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml new file mode 100644 index 0000000..a4c3bc1 --- /dev/null +++ b/.github/workflows/release.yaml @@ -0,0 +1,36 @@ +name: goreleaser + +on: + push: + tags: + - "v*" + +permissions: + contents: write + +jobs: + test: + uses: ./.github/workflows/test.yaml + secrets: inherit + goreleaser: + runs-on: ubuntu-latest + steps: + - + name: Checkout + uses: actions/checkout@v3 + with: + fetch-depth: 0 + - + name: Setup + uses: actions/setup-go@v5 + with: + go-version: '>=1.22' + - + name: Run GoReleaser + uses: goreleaser/goreleaser-action@v5 + with: + distribution: goreleaser + version: ${{ env.GITHUB_REF_NAME }} + args: release --rm-dist + env: + GITHUB_TOKEN: ${{ secrets.WEIGHTEDOPTION_RELEASE_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml new file mode 100644 index 0000000..219eadc --- /dev/null +++ b/.github/workflows/test.yaml @@ -0,0 +1,30 @@ +name: test + +on: + push: + branches: + - '*' + tags-ignore: + - '*' + pull_request: + workflow_call: + +jobs: + test: + runs-on: ubuntu-latest + name: Tests + steps: + - + name: Checkout + uses: actions/checkout@v3 + with: + fetch-depth: 0 + - + name: Setup + uses: actions/setup-go@v5 + with: + go-version: '>=1.22' + + - + name: Test + run: go test --race --shuffle on ./... \ No newline at end of file diff --git a/.gitignore b/.gitignore index 3b735ec..93e8ba5 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,9 @@ # Go workspace file go.work + +# MacOS +.DS_Store + +# GoLand +.idea/ diff --git a/.goreleaser.yaml b/.goreleaser.yaml new file mode 100644 index 0000000..e14141e --- /dev/null +++ b/.goreleaser.yaml @@ -0,0 +1,16 @@ +before: + hooks: + - go mod tidy + +builds: + - skip: true + +release: + prerelease: auto + +changelog: + sort: asc + filters: + exclude: + - '^docs:' + - '^test:' \ No newline at end of file diff --git a/README.md b/README.md index fdf6c76..f9c4f8c 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,43 @@ # weightedoption + A Go package for weighted random option selection + +## Example Usage + +```go +package main + +import ( + "fmt" + "log" + + "github.com/eljamo/weightedoption" +) + +// Simulates 100 chances for dropping a raid exotic weapon from a Destiny which has a 5% drop chance when a player completes the raid +func main() { + s, err := weightedoption.NewSelector( + weightedoption.NewOption('🔫', 5), + weightedoption.NewOption('❌', 95), + ) + if err != nil { + log.Fatal(err) + } + + chances := make([]rune, 30) + for i := 0; i < len(chances); i++ { + chances[i] = s.Select() + } + fmt.Println(string(chances)) + + tally := make(map[rune]int) + for _, c := range chances { + tally[c]++ + } + + _, err = fmt.Printf("\n🔫: %d\t❌ %d\n", tally['🔫'], tally['❌']) + if err != nil { + log.Fatal(err) + } +} +``` diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..5e823ef --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/eljamo/weightedoption + +go 1.22.1 diff --git a/weightedoption.go b/weightedoption.go new file mode 100644 index 0000000..92c2900 --- /dev/null +++ b/weightedoption.go @@ -0,0 +1,132 @@ +package weightedoption + +import ( + "errors" + "math" + "math/rand/v2" + "sort" +) + +var ( + ErrWeightOverflow = errors.New("sum of Option weights exceeds total") + ErrNoValidOptions = errors.New("0 Option(s) with Weight >= 1") +) + +// WeightIntegerConstraint is a type constraint for the Weight field of the Option struct. +type WeightIntegerConstraint interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr +} + +// Option is a struct that holds a data value and its associated weight. +type Option[DataType any, WeightIntegerType WeightIntegerConstraint] struct { + Data DataType + Weight WeightIntegerType +} + +// NewOption creates a new Option. +func NewOption[DataType any, WeightIntegerType WeightIntegerConstraint]( + data DataType, + weight WeightIntegerType, +) Option[DataType, WeightIntegerType] { + return Option[DataType, WeightIntegerType]{Data: data, Weight: weight} +} + +// SearchIntsFuncSignature is the signature of the function used to search for an integer in a sorted slice of integers. +type SearchIntsFuncSignature func(runningTotalWeights []int, randInt int) int + +// Selector is a struct that holds a slice of Options and their cumulative weights. +type Selector[DataType any, WeightIntegerType WeightIntegerConstraint] struct { + options []Option[DataType, WeightIntegerType] + runningTotalWeights []int + totalWeight int + searchIntsFunc SearchIntsFuncSignature +} + +// NewSelector creates a new Selector. +func NewSelector[DataType any, WeightIntegerType WeightIntegerConstraint]( + options ...Option[DataType, WeightIntegerType], +) (*Selector[DataType, WeightIntegerType], error) { + var filteredOptions []Option[DataType, WeightIntegerType] + for _, opt := range options { + if opt.Weight > 0 { + filteredOptions = append(filteredOptions, opt) + } + } + + sort.Slice(filteredOptions, func(i, j int) bool { + return filteredOptions[i].Weight < filteredOptions[j].Weight + }) + + runningTotalWeights := make([]int, len(filteredOptions)) + totalWeight := 0 + + for i, opt := range filteredOptions { + if uint(opt.Weight) >= math.MaxInt { + return nil, ErrWeightOverflow + } + + weight := int(opt.Weight) + if weight > math.MaxInt-totalWeight { + return nil, ErrWeightOverflow + } + + totalWeight += weight + runningTotalWeights[i] = totalWeight + } + + if totalWeight < 1 { + return nil, ErrNoValidOptions + } + + return &Selector[DataType, WeightIntegerType]{ + options: filteredOptions, + runningTotalWeights: runningTotalWeights, + totalWeight: totalWeight, + searchIntsFunc: searchInts, + }, nil +} + +// NewSelectorWithCustomSearchIntsFunc creates a new Selector with a custom searchIntsFunc. +func NewSelectorWithCustomSearchIntsFunc[DataType any, WeightIntegerType WeightIntegerConstraint]( + searchIntsFunc SearchIntsFuncSignature, + options ...Option[DataType, WeightIntegerType], +) (*Selector[DataType, WeightIntegerType], error) { + selector, err := NewSelector(options...) + if err != nil { + return nil, err + } + + selector.searchIntsFunc = searchIntsFunc + return selector, nil +} + +// NewSelectorUsingSortSearchInts creates a new Selector using the sort.SearchInts function. +func NewSelectorUsingSortSearchInts[DataType any, WeightIntegerType WeightIntegerConstraint]( + options ...Option[DataType, WeightIntegerType], +) (*Selector[DataType, WeightIntegerType], error) { + return NewSelectorWithCustomSearchIntsFunc(sort.SearchInts, options...) +} + +// Select returns a single option from the Selector. +func (s Selector[DataType, WeightIntegerType]) Select() DataType { + r := rand.IntN(s.totalWeight) + 1 + i := s.searchIntsFunc(s.runningTotalWeights, r) + return s.options[i].Data +} + +// searchInts searches for the index of the first element in runningTotalWeights +// that is greater than or equal to randInt. The slice must be sorted in +// ascending order. +func searchInts(runningTotalWeights []int, randInt int) int { + start, end := 0, len(runningTotalWeights) + for start < end { + mid := int(uint(start+end) >> 1) + if runningTotalWeights[mid] < randInt { + start = mid + 1 + } else { + end = mid + } + } + + return start +} diff --git a/weightedoption_32bit_test.go b/weightedoption_32bit_test.go new file mode 100644 index 0000000..4cf8ed6 --- /dev/null +++ b/weightedoption_32bit_test.go @@ -0,0 +1,34 @@ +//go:build 386 || arm || mips || mipsle +// +build 386 arm mips mipsle + +package weightedoption + +import ( + "math" + "testing" +) + +func TestNewSelector32Bit(t *testing.T) { + t.Parallel() + + u32tests := []struct { + name string + cs []Option[rune, uint32] + wantErr error + }{ + { + name: "weight overflow from single uint32 exceeding system math.MaxInt", + cs: []Option[rune, uint32]{{Data: 'a', Weight: uint32(math.MaxInt) + 1}}, + wantErr: ErrWeightOverflow, + }, + } + + for _, tt := range u32tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewSelector(tt.cs...) + if err != tt.wantErr { + t.Errorf("NewSelector() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/weightedoption_64bit_test.go b/weightedoption_64bit_test.go new file mode 100644 index 0000000..f634527 --- /dev/null +++ b/weightedoption_64bit_test.go @@ -0,0 +1,34 @@ +//go:build amd64 || arm64 || mips64 || mips64le || ppc64 || ppc64le || riscv64 || s390x || wasm +// +build amd64 arm64 mips64 mips64le ppc64 ppc64le riscv64 s390x wasm + +package weightedoption + +import ( + "math" + "testing" +) + +func TestNewSelector64Bit(t *testing.T) { + t.Parallel() + + u64tests := []struct { + name string + cs []Option[rune, uint64] + wantErr error + }{ + { + name: "weight overflow from single uint64 exceeding system math.MaxInt", + cs: []Option[rune, uint64]{{Data: 'a', Weight: uint64(math.MaxInt) + 1}}, + wantErr: ErrWeightOverflow, + }, + } + + for _, tt := range u64tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewSelector(tt.cs...) + if err != tt.wantErr { + t.Errorf("NewSelector() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/weightedoption_test.go b/weightedoption_test.go new file mode 100644 index 0000000..09deeb4 --- /dev/null +++ b/weightedoption_test.go @@ -0,0 +1,269 @@ +package weightedoption + +import ( + "fmt" + "math" + "math/rand" + "testing" +) + +const ( + testOptions int = 10 + testIterations int = 1_000_000 +) + +func TestNewSelector(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cs []Option[rune, int] + wantErr error + }{ + { + name: "no options", + cs: []Option[rune, int]{}, + wantErr: ErrNoValidOptions, + }, + { + name: "no options with weight greater than 0", + cs: []Option[rune, int]{{Data: 'a', Weight: 0}, {Data: 'b', Weight: 0}}, + wantErr: ErrNoValidOptions, + }, + { + name: "one option with weight greater than 0", + cs: []Option[rune, int]{{Data: 'a', Weight: 1}}, + wantErr: nil, + }, + { + name: "weight overflow", + cs: []Option[rune, int]{{Data: 'a', Weight: math.MaxInt/2 + 1}, {Data: 'b', Weight: math.MaxInt/2 + 1}}, + wantErr: ErrWeightOverflow, + }, + { + name: "nominal case", + cs: []Option[rune, int]{{Data: 'a', Weight: 1}, {Data: 'b', Weight: 2}}, + wantErr: nil, + }, + { + name: "one valid option and one invalid option with negative weight", + cs: []Option[rune, int]{{Data: 'a', Weight: 3}, {Data: 'b', Weight: -2}}, + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewSelector(tt.cs...) + if err != tt.wantErr { + t.Errorf("NewSelector() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestNewSelectorUsingSortSearchInts(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cs []Option[rune, int] + wantErr error + }{ + { + name: "no options", + cs: []Option[rune, int]{}, + wantErr: ErrNoValidOptions, + }, + { + name: "no options with weight greater than 0", + cs: []Option[rune, int]{{Data: 'a', Weight: 0}, {Data: 'b', Weight: 0}}, + wantErr: ErrNoValidOptions, + }, + { + name: "one option with weight greater than 0", + cs: []Option[rune, int]{{Data: 'a', Weight: 1}}, + wantErr: nil, + }, + { + name: "weight overflow", + cs: []Option[rune, int]{{Data: 'a', Weight: math.MaxInt/2 + 1}, {Data: 'b', Weight: math.MaxInt/2 + 1}}, + wantErr: ErrWeightOverflow, + }, + { + name: "nominal case", + cs: []Option[rune, int]{{Data: 'a', Weight: 1}, {Data: 'b', Weight: 2}}, + wantErr: nil, + }, + { + name: "one valid option and one invalid option with negative weight", + cs: []Option[rune, int]{{Data: 'a', Weight: 3}, {Data: 'b', Weight: -2}}, + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewSelectorUsingSortSearchInts(tt.cs...) + if err != tt.wantErr { + t.Errorf("NewSelectorUsingSortSearchInts() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestSelector_Select(t *testing.T) { + t.Parallel() + + options := mockFrequencyOptions(t, testOptions) + picker, err := NewSelector(options...) + if err != nil { + t.Fatal("Failed to create Selector:", err) + } + + counts := make(map[int]int) + for i := 0; i < testIterations; i++ { + c := picker.Select() + counts[c]++ + } + + verifyFrequencyCounts(t, counts, options) +} + +func mockFrequencyOptions(t *testing.T, n int) []Option[int, int] { + t.Helper() + options := make([]Option[int, int], 0, n) + for i := 1; i <= n; i++ { + c := NewOption(i, i) + options = append(options, c) + } + t.Log("Mocked options:", options) + return options +} + +func verifyFrequencyCounts(t *testing.T, counts map[int]int, options []Option[int, int]) { + t.Helper() + + for i := 0; i < len(options)-1; i++ { + if counts[options[i].Data] > counts[options[i+1].Data] { + t.Errorf( + "Option with lower weight %d (count: %d) was selected more than option with higher weight %d (count: %d)", + options[i].Weight, counts[options[i].Data], options[i+1].Weight, counts[options[i+1].Data], + ) + } + } +} + +const BMMinOptions int = 10 +const BMMaxOptions int = 10_000_000 + +func BenchmarkNewSelector(b *testing.B) { + for n := BMMinOptions; n <= BMMaxOptions; n *= 10 { + b.Run(fmt.Sprintf("size=%s", fmt1eN(n)), func(b *testing.B) { + options := mockOptions(n) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, _ = NewSelector(options...) + } + }) + } +} + +func BenchmarkSelect(b *testing.B) { + for n := BMMinOptions; n <= BMMaxOptions; n *= 10 { + b.Run(fmt.Sprintf("size=%s", fmt1eN(n)), func(b *testing.B) { + options := mockOptions(n) + selector, err := NewSelector(options...) + if err != nil { + b.Fatal(err) + } + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _ = selector.Select() + } + }) + } +} + +func BenchmarkSelectParallel(b *testing.B) { + for n := BMMinOptions; n <= BMMaxOptions; n *= 10 { + b.Run(fmt.Sprintf("size=%s", fmt1eN(n)), func(b *testing.B) { + options := mockOptions(n) + selector, err := NewSelector(options...) + if err != nil { + b.Fatal(err) + } + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _ = selector.Select() + } + }) + }) + } +} + +func BenchmarkNewSelectorUsingSortSearchInts(b *testing.B) { + for n := BMMinOptions; n <= BMMaxOptions; n *= 10 { + b.Run(fmt.Sprintf("size=%s", fmt1eN(n)), func(b *testing.B) { + options := mockOptions(n) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, _ = NewSelector(options...) + } + }) + } +} + +func BenchmarkSortSearchIntsSelect(b *testing.B) { + for n := BMMinOptions; n <= BMMaxOptions; n *= 10 { + b.Run(fmt.Sprintf("size=%s", fmt1eN(n)), func(b *testing.B) { + options := mockOptions(n) + selector, err := NewSelector(options...) + if err != nil { + b.Fatal(err) + } + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _ = selector.Select() + } + }) + } +} + +func BenchmarkSortSearchIntsSelectParallel(b *testing.B) { + for n := BMMinOptions; n <= BMMaxOptions; n *= 10 { + b.Run(fmt.Sprintf("size=%s", fmt1eN(n)), func(b *testing.B) { + options := mockOptions(n) + selector, err := NewSelector(options...) + if err != nil { + b.Fatal(err) + } + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _ = selector.Select() + } + }) + }) + } +} + +func mockOptions(n int) []Option[rune, int] { + options := make([]Option[rune, int], 0, n) + for i := 0; i < n; i++ { + s := 'O' + w := rand.Intn(10) + c := NewOption(s, w) + options = append(options, c) + } + return options +} + +// fmt1eN returns simplified order of magnitude scientific notation for n, +// e.g. "1e2" for 100, "1e7" for 10 million. +func fmt1eN(n int) string { + return fmt.Sprintf("1e%d", int(math.Log10(float64(n)))) +}