-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
290 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import numpy as np | ||
|
||
def log_jeffreys_posterior(X, y, w): | ||
n = len(y) | ||
u = np.dot(X, w) | ||
a = np.zeros(n) | ||
like = 0 | ||
for i in range(n): | ||
ui = u[i][0] | ||
p = 1 / (1.0 + np.exp(-ui)) | ||
q = 1 - p | ||
if y[i] == 1: | ||
like += np.log(p) | ||
else: | ||
like += np.log(q) | ||
a[i] = p * q | ||
H = np.dot(X.T, np.dot(np.diag(a), X)) | ||
L = np.linalg.cholesky(H) | ||
prior = np.sum(np.log(np.diag(L))) | ||
return like + prior |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from ._errors import plot_error_comparison | ||
|
||
__all__ = [ | ||
'plot_error_comparison', | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
from collections import defaultdict, namedtuple | ||
import numpy as np | ||
|
||
SummaryStats = namedtuple('SummaryStats', ['min', 'max', 'mean']) | ||
|
||
def compute_range(errs): | ||
low = float("inf") | ||
high = -low | ||
for errs_i in errs: | ||
low = min(np.min(errs_i), low) | ||
high = max(np.max(errs_i), high) | ||
return low, high | ||
|
||
def bin_errors(errs, bin_size, low): | ||
res = defaultdict(int) | ||
for err in errs: | ||
t = np.rint((err - low) / bin_size) | ||
res[t] += 1 | ||
return res | ||
|
||
def summary_stats(errs): | ||
min_err = np.min(errs) | ||
max_err = np.max(errs) | ||
mean_err = np.mean(errs) | ||
return SummaryStats(min_err, max_err, mean_err) | ||
|
||
def points(bins, a, m): | ||
res = [] | ||
for err, cnt in bins.items(): | ||
for i in range(cnt): | ||
res.append((m * i + a, err)) | ||
return np.array(res) | ||
|
||
def plot_error_comparison(ax, left_errs, right_errs, nbins): | ||
left_stats = summary_stats(left_errs) | ||
right_stats = summary_stats(right_errs) | ||
left_errs = np.log(left_errs) | ||
right_errs = np.log(right_errs) | ||
low, high = compute_range((left_errs, right_errs)) | ||
bin_size = (high - low) / nbins | ||
left_bins = bin_errors(left_errs, bin_size, low) | ||
right_bins = bin_errors(right_errs, bin_size, low) | ||
left_pts = points(left_bins, 0, -1) | ||
right_pts = points(right_bins, 1, 1) | ||
|
||
ax.spines['top'].set_visible(False) | ||
ax.spines['right'].set_visible(False) | ||
ax.spines['bottom'].set_visible(False) | ||
ax.xaxis.set_ticks([]) | ||
|
||
yticks = [ | ||
left_stats.mean, | ||
right_stats.mean, | ||
] | ||
for max_err in [left_stats.max, right_stats.max]: | ||
yticks.append(max_err) | ||
yticks.append(np.exp(low)) | ||
ylabels = ['%0.2e' % x for x in yticks] | ||
ax.set_yscale('log') | ||
ax.yaxis.set_ticks(yticks) | ||
ax.yaxis.set_ticklabels(ylabels) | ||
ax.set_ylim(np.exp([low, high])) | ||
ax.yaxis.grid(color='white', linewidth=0.5) | ||
ax.tick_params(direction='in') | ||
|
||
ax.scatter(left_pts[:, 0], np.exp(left_pts[:, 1] * bin_size + low), s=0.5, c='tab:orange') | ||
ax.scatter(right_pts[:, 0], np.exp(right_pts[:, 1] * bin_size + low), s=0.5, c='tab:blue') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from ._normal_mean_hypothesis import NormalMeanHypothesis | ||
|
||
__all__ = [ | ||
NormalMeanHypothesis, | ||
] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
import numpy as np | ||
import scipy | ||
from collections import namedtuple | ||
|
||
from bbai._computation._bridge import stat | ||
|
||
class _NormalMeanHypothesisResult3: | ||
def __init__(self, z, n, sigma=None): | ||
t = np.sqrt(n - 1) * z | ||
|
||
dist = scipy.stats.t(n - 1) | ||
|
||
pdf_t = dist.pdf(t) | ||
|
||
# self.left_to_star_N corresponds to the value B_{20}^N in [1] | ||
self.left_to_star_N = dist.cdf(-t) | ||
|
||
# self.right_to_star_N corresponds to the value B_{30}^N in [1] | ||
self.right_to_star_N = 1 - self.left_to_star_N | ||
|
||
# self.equal_to_star_N corresponds to the value B_{10}^N in [1] | ||
if sigma is not None: | ||
self.equal_to_star_N = np.sqrt(n-1) * pdf_t / sigma | ||
else: | ||
self.equal_to_star_N = None | ||
|
||
|
||
root2_z = np.sqrt(2) * z | ||
|
||
g1 = stat.normal_eeibf_g1(root2_z) | ||
|
||
# self.correction_left corresponds to the value | ||
# E_{\hat{\theta}}^{H_0}[B_{20}^N(X(l))] | ||
# in [1] sec 2.4.2 | ||
self.correction_left = 0.5 - g1 / np.pi | ||
|
||
# self.correction_right corresponds to the value | ||
# E_{\hat{\theta}}^{H_0}[B_{30}^N(X(l))] | ||
# in [1] sec 2.4.2 | ||
self.correction_right = 0.5 + g1 / np.pi | ||
|
||
# self.factor_left corresponds to the value | ||
# B_{20}^{EEI} | ||
# in [1] sec 2.4.2 | ||
self.factor_left = self.left_to_star_N / self.correction_left | ||
|
||
# self.factor_right corresponds to the value | ||
# B_{30}^{EEI} | ||
# in [1] sec 2.4.2 | ||
self.factor_right = self.right_to_star_N / self.correction_right | ||
|
||
g2 = stat.normal_eeibf_g2(root2_z) | ||
|
||
# self.correction_equal corresponds to the value | ||
# E_{\hat{\theta}}^{H_0}[B_{10}^N(X(l))] | ||
# in [1] sec 2.4.2 | ||
correction_equal = np.sqrt(2) * g2 / np.pi | ||
if sigma is not None: | ||
self.correction_equal = correction_equal / sigma | ||
else: | ||
self.correction_equal = None | ||
|
||
# self.factor_equal corresponds to the value | ||
# B_{10}^{EEI} | ||
# in [1] sec 2.4.2 | ||
self.factor_equal = pdf_t * np.sqrt(n - 1) / correction_equal | ||
|
||
# self.left, self.equal, self.right are the posterior probabilities | ||
# of the three hypotheses | ||
total = self.factor_left + self.factor_equal + self.factor_right | ||
self.left = self.factor_left / total | ||
self.equal = self.factor_equal / total | ||
self.right = self.factor_right / total | ||
|
||
class _NormalMeanHypothesisResult2: | ||
def __init__(self, z, n): | ||
res3 = _NormalMeanHypothesisResult3(z, n) | ||
|
||
self.left_to_star_N = res3.left_to_star_N | ||
self.right_to_star_N = res3.right_to_star_N | ||
|
||
|
||
self.correction_left = res3.correction_left | ||
self.correction_right = res3.correction_right | ||
|
||
self.factor_left = res3.factor_left | ||
self.factor_right = res3.factor_right | ||
|
||
|
||
total = self.factor_left + self.factor_right | ||
self.left = self.factor_left / total | ||
self.right = self.factor_right / total | ||
|
||
class NormalMeanHypothesis: | ||
"""Implements normal mean hypothesis testing with unknown variance using the | ||
encompassing expected intrinsic Bayes factor (EEIBF) method described in the paper | ||
1: Default Bayes Factors for Nonnested Hypothesis Testing | ||
by J. O. Berger and J. Mortera | ||
url: https://www.jstor.org/stable/2670175 | ||
postscript: http://www2.stat.duke.edu/~berger/papers/mortera.ps | ||
The class provides both two tailed hypothesis testing (i.e. H_1: mu < 0, H_2: mu >= 0) and | ||
hypothesis testing with equality (H_1: mu == 0, H_2: mu < 0, H_3: mu > 0). | ||
The algorithm uses an accurate and efficient deterministic algorithm to compute the | ||
correction factors | ||
E_{\hat{\theta}}^{H_0}[B_{i0}^N (X(l))] | ||
(see [1] section 2.4.2) | ||
Parameters | ||
---------- | ||
mu0 : double, default=0 | ||
The split point for the hypotheses | ||
with_equal: bool, default=True | ||
Whether to include an equals hypothesis | ||
Examples | ||
-------- | ||
>>> from bbai.stat import NormalMeanHypothesis | ||
>>> y = np.random.normal(size=9) | ||
>>> res = NormalMeanHypothesis().test(y) | ||
>>> print(res.left) # posterior probability for the hypothesis mu < 0 | ||
>>> print(res.equal) # posterior probability for the hypothesis mu == 0 | ||
>>> print(res.right) # posterior probability for the hypothesis mu > 0 | ||
""" | ||
def __init__(self, mu0 = 0, with_equal=True): | ||
self.mu0_ = mu0 | ||
self.with_equal_ = with_equal | ||
|
||
def test(self, data): | ||
n = len(data) | ||
mean = np.mean(data) | ||
std = np.std(data) | ||
z = (mean - self.mu0_) / std | ||
if self.with_equal_: | ||
return _NormalMeanHypothesisResult3(z, n, std) | ||
else: | ||
return _NormalMeanHypothesisResult2(z, n) | ||
|
||
def test_t(self, t, n): | ||
z = t / np.sqrt(n - 1) | ||
if self.with_equal_: | ||
return _NormalMeanHypothesisResult3(z, n) | ||
else: | ||
return _NormalMeanHypothesisResult2(z, n) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import pytest | ||
import numpy as np | ||
from bbai.stat import NormalMeanHypothesis | ||
|
||
def test_normal_mean_hypothesis(): | ||
hy = NormalMeanHypothesis() | ||
data = np.array([1, 2, 3]) | ||
|
||
t1 = hy.test(data) | ||
assert t1.left + t1.equal + t1.right == 1.0 | ||
assert t1.left < t1.right | ||
assert t1.equal < t1.right | ||
|
||
t2 = hy.test(-data) | ||
assert t2.left == pytest.approx(t1.right) | ||
assert t2.equal == pytest.approx(t1.equal) | ||
|
||
t3 = hy.test(list(data) + [4]) | ||
assert t3.right > t1.right | ||
|
||
t4 = NormalMeanHypothesis(mu0=1.23).test(data + 1.23) | ||
assert t4.left == pytest.approx(t1.left) | ||
assert t4.equal == pytest.approx(t1.equal) | ||
assert t4.right == pytest.approx(t1.right) | ||
|
||
def test_normal_mean_two_tailed_hypothesis(): | ||
h1 = NormalMeanHypothesis() | ||
h1p = NormalMeanHypothesis(with_equal=False) | ||
data = np.array([1, 2, 3]) | ||
|
||
t1 = h1.test(data) | ||
t1p = h1p.test(data) | ||
assert t1.factor_left == t1p.factor_left | ||
assert t1.factor_right == t1p.factor_right | ||
assert t1p.left == pytest.approx(1 - t1p.right) | ||
assert t1p.left == pytest.approx(t1.factor_left / (t1.factor_left + t1.factor_right)) | ||
|
||
if __name__ == "__main__": | ||
raise SystemExit(pytest.main([__file__])) |
Binary file not shown.