Skip to content

Commit

Permalink
parallelize Thompson probability calculations (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
btamadio authored Mar 1, 2021
1 parent 986c7d8 commit 456abc3
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 29 deletions.
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB
golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sync v0.0.0-20190423024810-112230192c58 h1:8gQV6CLnAEikrhgkHFbMAEhagSSnXWGV915qUMm9mrU=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
Expand Down
15 changes: 8 additions & 7 deletions numint/gauss_legendre.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,37 +24,38 @@ func GaussLegendre(degree int) *GaussLegendreRule {
return &GaussLegendreRule{
abscissae: abscissae,
weightCoeffs: weightCoeffs,
weights: make([]float64, len(weightCoeffs)),
points: make([]float64, len(abscissae)),
}
}

// GaussLegendreRule provides Weights and Points functions for Gauss Legendre quadrature rules.
type GaussLegendreRule struct {
abscissae, weightCoeffs []float64
weights, points []float64
}

// Weights returns the quadrature weights to use for the interval [a, b].
// The number of points returned depends on the degree of the rule.
func (g *GaussLegendreRule) Weights(a float64, b float64) []float64 {

weights := make([]float64, len(g.weightCoeffs))

for i := range g.weightCoeffs {
g.weights[i] = g.weightCoeffs[i] * (b - a) / 2
weights[i] = g.weightCoeffs[i] * (b - a) / 2
}

return g.weights
return weights
}

// Points returns the quadrature sampling points to use for the interval [a, b].
// The number of points returned depends on the degree of the rule.
func (g GaussLegendreRule) Points(a float64, b float64) []float64 {

points := make([]float64, len(g.abscissae))

for i := range g.abscissae {
g.points[i] = g.abscissae[i]*(b-a)/2 + (b+a)/2
points[i] = g.abscissae[i]*(b-a)/2 + (b+a)/2
}

return g.points
return points
}

// source: http://www.holoborodko.com/pavel/numerical-methods/numerical-integration/
Expand Down
70 changes: 48 additions & 22 deletions thompson.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
package mab

import (
"sync"
)

func NewThompson(integrator Integrator) *Thompson {
return &Thompson{
integrator: integrator,
Expand All @@ -8,8 +12,6 @@ func NewThompson(integrator Integrator) *Thompson {

type Thompson struct {
integrator Integrator
rewards []Dist
probs []float64
}

type Integrator interface {
Expand All @@ -21,39 +23,63 @@ func (t *Thompson) ComputeProbs(rewards []Dist) ([]float64, error) {
return []float64{}, nil
}

t.rewards = rewards
return t.computeProbs()
integrals := t.integrals(rewards)
return t.integrateParallel(integrals)
}

func (t *Thompson) computeProbs() ([]float64, error) {
t.probs = make([]float64, len(t.rewards))
for arm := range t.rewards {
prob, err := t.computeProb(arm)
if err != nil {
return nil, err
}
t.probs[arm] = prob
}
return t.probs, nil
type integral struct {
integrand integrand
interval interval
}

func (t *Thompson) computeProb(arm int) (float64, error) {
integrand := t.integrand(arm)
xMin, xMax := t.rewards[arm].Support()
type integrand func(float64) float64
type interval struct{ a, b float64 }

return t.integrator.Integrate(integrand, xMin, xMax)
func (t *Thompson) integrals(rewards []Dist) []integral {
result := make([]integral, len(rewards))
for i := range rewards {
result[i].integrand = t.integrand(rewards, i)
result[i].interval.a, result[i].interval.b = rewards[i].Support()
}
return result
}

func (t *Thompson) integrand(arm int) func(float64) float64 {
func (t *Thompson) integrand(rewards []Dist, arm int) integrand {
return func(x float64) float64 {
total := t.rewards[arm].Prob(x)
for j := range t.rewards {
total := rewards[arm].Prob(x)
for j := range rewards {
if arm == j {
continue
}

total *= t.rewards[j].CDF(x)
total *= rewards[j].CDF(x)
}
return total
}
}

func (t *Thompson) integrateParallel(integrals []integral) ([]float64, error) {
n := len(integrals)

results := make([]float64, n)
errs := make([]error, n)

var wg sync.WaitGroup
for i := 0; i < n; i++ {
wg.Add(1)
go func(i int, xi integral) {
results[i], errs[i] = t.integrator.Integrate(xi.integrand, xi.interval.a, xi.interval.b)
wg.Done()
}(i, integrals[i])
}

wg.Wait()

for _, err := range errs {
if err != nil {
return nil, err
}
}

return results, nil
}

0 comments on commit 456abc3

Please sign in to comment.