from functools import lru_cache
import numpy as np
from scipy import stats
from .utils import check_args, import_plt, call_shortcut
__all__ = ["CorrelatedTTest", "two_on_single"]
class Posterior:
"""
The posterior distribution of differences on a single data set.
Args:
mean (float): the mean difference
var (float): the variance
df (float): degrees of freedom
        rope (float): the width of the region of practical equivalence (default: 0)
meanx (float): mean score of the first classifier; shown in a plot
meany (float): mean score of the second classifier; shown in a plot
names (tuple of str): names of classifiers; shown in a plot
nsamples (int): the number of samples; used only in property `sample`,
not in computation of probabilities or plotting (default: 50000)
"""
def __init__(self, mean, var, df, rope=0, meanx=None, meany=None,
*, names=None, nsamples=50000):
self.meanx = meanx
self.meany = meany
self.mean = mean
self.var = var
self.df = df
self.rope = rope
self.names = names
self.nsamples = nsamples
@property
@lru_cache(1)
def sample(self):
"""
        A sample of differences as a 1-dimensional array.

        As with the posteriors for comparisons on multiple data sets, an
        instance of this class always returns the same sample.
        The sample is not used by the other methods.
"""
if self.var == 0:
return np.full((self.nsamples, ), self.mean)
return self.mean + \
np.sqrt(self.var) * np.random.standard_t(self.df, self.nsamples)
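    # A minimal usage sketch (assumed usage, not part of the original API
    # surface): the cached sample supports Monte Carlo summaries, e.g.
    #
    #     post = Posterior(mean=0.02, var=1e-4, df=9, rope=0.01)
    #     p_right_mc = np.mean(post.sample > post.rope)
    #
    # `probs` below computes the same quantities exactly from the t CDF.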
    def probs(self):
"""
        Compute and return the probabilities.

        Probabilities are computed from the cumulative Student distribution,
        not from the posterior sample.
Returns:
`(p_left, p_rope, p_right)` if `rope > 0`;
otherwise `(p_left, p_right)`.
"""
t_parameters = self.df, self.mean, np.sqrt(self.var)
if self.rope == 0:
if self.var == 0:
pr = (self.mean > 0) + 0.5 * (self.mean == 0)
else:
pr = 1 - stats.t.cdf(0, *t_parameters)
return 1 - pr, pr
else:
if self.var == 0:
pl = float(self.mean < -self.rope)
pr = float(self.mean > self.rope)
else:
pl = stats.t.cdf(-self.rope, *t_parameters)
pr = 1 - stats.t.cdf(self.rope, *t_parameters)
return pl, 1 - pl - pr, pr
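    # Worked example (hypothetical numbers): with mean=0.02, var=1e-4
    # (i.e. scale sqrt(var) = 0.01), df=9 and rope=0.01,
    #
    #     pl = stats.t.cdf(-0.01, 9, loc=0.02, scale=0.01)
    #     pr = 1 - stats.t.cdf(0.01, 9, loc=0.02, scale=0.01)
    #
    # and probs() returns (pl, 1 - pl - pr, pr), which sums to 1.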
    def plot(self, names=None):
"""
Plot the posterior Student distribution as a histogram.
Args:
            names (tuple of str): names of classifiers; defaults to the
                instance's `names`, then to `("C1", "C2")`
Returns:
matplotlib figure
"""
plt = import_plt()
names = names or self.names or ("C1", "C2")
fig, ax = plt.subplots()
ax.grid(True)
label = "difference"
if self.meanx is not None and self.meany is not None:
label += " ({}: {:.3f}, {}: {:.3f})".format(
names[0], self.meanx, names[1], self.meany)
ax.set_xlabel(label)
ax.get_yaxis().set_ticklabels([])
ax.axvline(x=-self.rope, color="#ffad2f", linewidth=2, label="rope")
ax.axvline(x=self.rope, color="#ffad2f", linewidth=2)
targs = (self.df, self.mean, np.sqrt(self.var))
xs = np.linspace(min(stats.t.ppf(0.005, *targs), -1.05 * self.rope),
max(stats.t.ppf(0.995, *targs), 1.05 * self.rope),
100)
ys = stats.t.pdf(xs, *targs)
ax.plot(xs, ys, color="#2f56e0", linewidth=2, label="pdf")
ax.fill_between(xs, ys, np.zeros(100), color="#34ccff")
ax.legend()
return fig
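# A minimal plotting sketch for the class above (assumed usage, not from the
# original module); requires matplotlib, which `import_plt` checks for:
#
#     post = Posterior(mean=0.02, var=1e-4, df=9, rope=0.01,
#                      meanx=0.85, meany=0.83, names=("nbc", "aode"))
#     fig = post.plot()
#     fig.savefig("posterior.png")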
def two_on_single(x, y, rope=0, runs=1, *, names=None, plot=False):
"""
Compute probabilities using a Bayesian correlated t-test and,
optionally, draw a histogram.
    The test assumes that the classifiers were evaluated using
    cross-validation; the argument `runs` gives the number of its repetitions.

    For more details, see :obj:`CorrelatedTTest`.
    Args:
        x (np.array): a vector of scores for the first model
        y (np.array): a vector of scores for the second model
        rope (float): the width of the region of practical equivalence
            (default: 0)
        runs (int): the number of repetitions of cross-validation (default: 1)
        names (tuple of str): names of classifiers (ignored if `plot` is `False`)
        plot (bool): if `True`, the function also returns a histogram
            (default: False)

    Returns:
        `(p_left, p_rope, p_right)` if `rope > 0`; otherwise `(p_left, p_right)`.
        If `plot=True`, the function also returns a matplotlib figure,
        that is, `((p_left, p_rope, p_right), fig)`.
"""
return call_shortcut(CorrelatedTTest, x, y, rope,
plot=plot, names=names, runs=runs)
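

if __name__ == "__main__":
    # A small self-contained demo (not part of the original module): compare
    # two hypothetical classifiers on scores from a single 10-fold
    # cross-validation, with a rope of half a percentage point. The score
    # vectors are synthetic, for illustration only.
    rng = np.random.default_rng(0)
    scores_a = 0.85 + 0.02 * rng.standard_normal(10)
    scores_b = 0.83 + 0.02 * rng.standard_normal(10)
    p_left, p_rope, p_right = two_on_single(scores_a, scores_b, rope=0.005)
    print("p_left={:.3f}, p_rope={:.3f}, p_right={:.3f}".format(
        p_left, p_rope, p_right))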