-
Notifications
You must be signed in to change notification settings - Fork 7
/
dists.go
92 lines (73 loc) · 1.83 KB
/
dists.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
package mab
import (
"fmt"
"math"
"gonum.org/v1/gonum/stat/distuv"
)
// Normal builds a NormalDist with mean mu and standard deviation sigma.
// The result can be used with any bandit strategy.
// For the purposes of Thompson sampling, the distribution is treated as
// truncated at mean +/- 4*sigma.
func Normal(mu, sigma float64) NormalDist {
	inner := distuv.Normal{
		Mu:    mu,
		Sigma: sigma,
	}
	return NormalDist{Normal: inner}
}
// NormalDist is a normal distribution backed by an embedded distuv.Normal,
// extended with a finite Support and a String representation.
type NormalDist struct {
	distuv.Normal
}

// Support returns the lower and upper bounds of the interval
// Mu - 4*Sigma to Mu + 4*Sigma, in that order.
func (n NormalDist) Support() (float64, float64) {
	// Number of standard deviations kept on each side of the mean.
	const sigmas = 4.0

	halfWidth := sigmas * n.Sigma
	lower, upper := n.Mu-halfWidth, n.Mu+halfWidth
	return lower, upper
}

// String formats the distribution as "Normal(mu,sigma)" using %f for both
// parameters.
func (n NormalDist) String() string {
	mean, stdDev := n.Mu, n.Sigma
	return fmt.Sprintf("Normal(%f,%f)", mean, stdDev)
}
// Beta builds a BetaDist with shape parameters alpha and beta.
// The result can be used with any bandit strategy.
func Beta(alpha, beta float64) BetaDist {
	inner := distuv.Beta{
		Alpha: alpha,
		Beta:  beta,
	}
	return BetaDist{Beta: inner}
}
// BetaDist is a beta distribution backed by an embedded distuv.Beta,
// extended with a Support and a String representation.
type BetaDist struct {
	distuv.Beta
}

// Support returns the bounds 0 and 1, regardless of the shape parameters.
func (b BetaDist) Support() (float64, float64) {
	const lower, upper = 0.0, 1.0
	return lower, upper
}

// String formats the distribution as "Beta(alpha,beta)" using %f for both
// parameters.
func (b BetaDist) String() string {
	// b.Beta is the embedded distuv.Beta value; its own Beta field is the
	// second shape parameter.
	params := b.Beta
	return fmt.Sprintf("Beta(%f,%f)", params.Alpha, params.Beta)
}
// Point builds a PointDist located at mu. It is meant for reward models that
// only provide point estimates. Don't use with Thompson sampling.
func Point(mu float64) PointDist {
	return PointDist{Mu: mu}
}
// PointDist is a degenerate distribution concentrated on the single value Mu.
type PointDist struct {
	// Mu is the only value the distribution takes.
	Mu float64
}

// Mean returns Mu.
func (p PointDist) Mean() float64 {
	return p.Mu
}

// CDF is a step function: it returns 1 when x >= Mu and 0 otherwise.
// A NaN on either side of the comparison yields 0.
func (p PointDist) CDF(x float64) float64 {
	var cumulative float64
	if x >= p.Mu {
		cumulative = 1
	}
	return cumulative
}

// Prob returns 0 for every x different from Mu and NaN when x equals Mu.
// NOTE(review): NaN presumably marks the density as undefined at the point
// mass; confirm callers do not expect +Inf here.
func (p PointDist) Prob(x float64) float64 {
	if x != p.Mu {
		return 0
	}
	return math.NaN()
}

// Rand always returns Mu; no randomness is involved.
func (p PointDist) Rand() float64 {
	return p.Mean()
}

// Support returns Mu for both the lower and the upper bound.
func (p PointDist) Support() (float64, float64) {
	at := p.Mu
	return at, at
}

// String formats the distribution as "Point(mu)" using %f, except when Mu is
// negative infinity, in which case it returns "Null()".
func (p PointDist) String() string {
	if !math.IsInf(p.Mu, -1) {
		return fmt.Sprintf("Point(%f)", p.Mu)
	}
	return "Null()"
}
// Null returns a PointDist whose mean is negative infinity. This is a special
// value that indicates to a Strategy that this arm should get selection
// probability zero.
func Null() PointDist {
	negInf := math.Inf(-1)
	return PointDist{Mu: negInf}
}