Beta Models

beta_models

Functions:

average_betas

average_betas(beta1_params: ArrayLike, beta2_params: ArrayLike, W1: float = 0.5, W2: float = 0.5) -> ndarray

Average two beta distributions by taking a weighted sum of their parameters, with weights W1 and W2.

Parameters:

  • beta1_params

    (ArrayLike) –

    Parameters of first beta distribution.

  • beta2_params

    (ArrayLike) –

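    Parameters of second beta distribution.

  • W1

    (float, default: 0.5 ) –

    Weight applied to the parameters of the first distribution.

  • W2

    (float, default: 0.5 ) –

    Weight applied to the parameters of the second distribution.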

Returns:

  • ndarray

    jnp.ndarray: New beta distribution parameters.
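
As a worked illustration of the weighting, the sketch below re-implements the same arithmetic in plain Python for a single pair of distributions. The helper name `average_betas_scalar` is hypothetical and exists only for this example; the library function itself is JIT-compiled with JAX and operates on arrays.

```python
def average_betas_scalar(beta1, beta2, w1=0.5, w2=0.5):
    """Weighted sum of two (alpha, beta) pairs, mirroring average_betas (illustrative helper)."""
    a1, b1 = beta1
    a2, b2 = beta2
    return (w1 * a1 + w2 * a2, w1 * b1 + w2 * b2)


# Equal weights: Beta(2, 4) and Beta(6, 2) combine to Beta(4, 3)
equal = average_betas_scalar((2.0, 4.0), (6.0, 2.0))

# Unequal weights: the first distribution contributes three times as much
skewed = average_betas_scalar((2.0, 4.0), (6.0, 2.0), w1=0.75, w2=0.25)

print(equal)   # (4.0, 3.0)
print(skewed)  # (3.0, 3.5)
```

Note that the weights are applied as given and are not normalised, so weights that do not sum to 1 change the total evidence (alpha + beta) of the result.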

Source code in behavioural_modelling/learning/beta_models.py
@jax.jit
def average_betas(
    beta1_params: ArrayLike,
    beta2_params: ArrayLike,
    W1: float = 0.5,
    W2: float = 0.5,
) -> jnp.ndarray:
    """
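    Average two beta distributions, weighted by W1 and W2.

    Args:
        beta1_params (ArrayLike): Parameters of first beta distribution.
        beta2_params (ArrayLike): Parameters of second beta distribution.
        W1 (float, optional): Weight applied to the parameters of the first distribution. Defaults to 0.5.
        W2 (float, optional): Weight applied to the parameters of the second distribution. Defaults to 0.5.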

    Returns:
        jnp.ndarray: New beta distribution parameters.
    """

    # Extract parameters
    a1 = beta1_params[..., 0]
    b1 = beta1_params[..., 1]
    a2 = beta2_params[..., 0]
    b2 = beta2_params[..., 1]

    # Calculate average
    a_new = (W1 * a1) + (W2 * a2)
    b_new = (W1 * b1) + (W2 * b2)

    # Return new parameters
    return jnp.stack([a_new, b_new], axis=-1)

beta_mean_var

beta_mean_var(beta_params: ArrayLike) -> Tuple[ArrayLike, ArrayLike]

Calculate mean and variance of a beta distribution.

Parameters:

  • beta_params

    (ArrayLike) –

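    Parameters of the beta distribution. Of shape (n_options, 2), where the first dimension represents the number of options (each of which has its own beta distribution), and the second dimension represents the alpha and beta parameters of each beta distribution.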

Returns:

  • Tuple[ArrayLike, ArrayLike]

    tuple[ArrayLike, ArrayLike]: Mean and variance of the beta distribution.
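
To make the two formulas concrete, here is a minimal plain-Python sketch for a single (alpha, beta) pair. The helper name `beta_mean_var_scalar` is an assumption made for this example; the library function applies the same expressions element-wise to an array of shape (n_options, 2).

```python
import math


def beta_mean_var_scalar(a, b):
    """Mean and variance of Beta(a, b), using the same expressions as beta_mean_var."""
    mean = a / (a + b)
    var = (a * b) / ((a + b) ** 2 * (a + b + 1))
    return mean, var


# Beta(2, 2): symmetric around 0.5
mean_22, var_22 = beta_mean_var_scalar(2.0, 2.0)

# Beta(1, 1): the flat (uniform) distribution on [0, 1], variance 1/12
mean_flat, var_flat = beta_mean_var_scalar(1.0, 1.0)

print(mean_22, var_22)
print(mean_flat, math.isclose(var_flat, 1 / 12))
```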

Source code in behavioural_modelling/learning/beta_models.py
@jax.jit
def beta_mean_var(beta_params: ArrayLike) -> Tuple[ArrayLike, ArrayLike]:
    """
    Calculate mean and variance of a beta distribution.

    Args:
        beta_params (ArrayLike): Parameters of the beta distribution. Of shape (n_options, 2),
        where the first dimension represents the number of options (each of which has its own
        beta distribution), and the second dimension represents the alpha and beta parameters
        of each beta distribution.

    Returns:
        tuple[ArrayLike, ArrayLike]: Mean and variance of the beta distribution.
    """
    a, b = beta_params[..., 0], beta_params[..., 1]
    mean = a / (a + b)
    var = (a * b) / ((a + b) ** 2 * (a + b + 1))
    return mean, var

generalised_beta_mean_var

generalised_beta_mean_var(alpha: float, beta: float, a: float, b: float) -> Tuple[float, float]

Calculate mean and variance of a generalised beta distribution.

Parameters:

  • alpha

    (float) –

    Alpha parameter of the beta distribution.

  • beta

    (float) –

    Beta parameter of the beta distribution.

  • a

    (float) –

    Lower bound of the beta distribution.

  • b

    (float) –

    Upper bound of the beta distribution.

Returns:

  • Tuple[float, float]

    tuple[float, float]: Mean and variance of the beta distribution.
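
The sketch below works through the same two expressions in plain Python, to show how the bounds rescale the standard beta moments. The helper name `generalised_beta_mean_var_sketch` is hypothetical and used only here.

```python
import math


def generalised_beta_mean_var_sketch(alpha, beta, a, b):
    """Mean and variance of a beta distribution rescaled to the interval [a, b]."""
    mean = ((b - a) * alpha) / (alpha + beta) + a
    var = (alpha * beta * (b - a) ** 2) / ((alpha + beta) ** 2 * (alpha + beta + 1))
    return mean, var


# Beta(2, 2) stretched from [0, 1] to [-1, 3]:
# the mean moves to the midpoint of the interval and the variance scales by (b - a)^2 = 16
mean, var = generalised_beta_mean_var_sketch(2.0, 2.0, a=-1.0, b=3.0)

print(mean, var)
```

With bounds a = 0 and b = 1 the expressions reduce to the standard beta mean and variance.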

Source code in behavioural_modelling/learning/beta_models.py
@jax.jit
def generalised_beta_mean_var(
    alpha: float, beta: float, a: float, b: float
) -> Tuple[float, float]:
    """
    Calculate mean and variance of a generalised beta distribution.

    Args:
        alpha (float): Alpha parameter of the beta distribution.
        beta (float): Beta parameter of the beta distribution.
        a (float): Lower bound of the beta distribution.
        b (float): Upper bound of the beta distribution.

    Returns:
        tuple[float, float]: Mean and variance of the beta distribution.
    """
    mean = ((b - a) * alpha) / (alpha + beta) + a
    var = (alpha * beta * (b - a) ** 2) / ((alpha + beta) ** 2 * (alpha + beta + 1))
    return mean, var

leaky_beta_update

leaky_beta_update(estimate: ArrayLike, choices: ArrayLike, outcome: float, tau_p: float, tau_n: float, decay: float, update: int = 1, increment: int = 1) -> ndarray

Update estimates using the (asymmetric) leaky beta model.

This model represents the probability of the outcome associated with each option (e.g., bandits in a bandit task) as a beta distribution.

Values are updated according to the following equations:

\[ A_i^{t+1} = \lambda \cdot A_i^{t} + outcome_t \cdot \tau^{+} \]
\[ B_i^{t+1} = \lambda \cdot B_i^{t} + (1-outcome_t) \cdot \tau^{-} \]

This function also allows for updating to be turned off (i.e., the estimate is not updated at all) and for incrementing to be turned off (i.e., decay is applied, but the outcome is not registered).

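Only chosen options are incremented, but all options decay. In the implementation, decay is applied to each parameter's distance from a baseline of 1 (i.e., \( 1 + \lambda \cdot (A_i^{t} - 1) \)), so that parameters decay towards a flat distribution, and the outcome is binarised (any outcome greater than 0 is treated as 1).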

Parameters:

  • estimate

    (ArrayLike) –

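    Alpha and beta estimates for this trial. Should be an array of shape (n, 2) where the first dimension represents the options and the second dimension represents the alpha and beta parameters of each option's beta distribution.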

  • choices

    (ArrayLike) –

    Choices made in this trial. Should have as many entries as there are options, with zeros for non-chosen options and ones for chosen options (i.e., one-hot encoded).

  • outcome

    (float) –

    Observed outcome for this trial.

  • tau_p

    (float) –

    Update rate for outcomes equal to 1.

  • tau_n

    (float) –

    Update rate for outcomes equal to 0.

  • decay

    (float) –

    Decay rate.

  • update

    (int, default: 1 ) –

    Whether to update the estimate. If 0, the estimate is not updated (i.e., no decay is applied, and the outcome of the trial does not affect the estimate).

  • increment

    (int, default: 1 ) –

    Whether to increment the estimate. If 0, the estimate is not incremented but decay is applied.

Returns:

  • ndarray

    jnp.ndarray: Updated value estimates for this trial, with one entry per bandit.
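
The effect of the `update` and `increment` switches is easiest to see on a small example. The following is a plain-Python sketch that mirrors the arithmetic in the source listing (including the baseline of 1 and the binarised outcome); the helper name `leaky_beta_update_sketch` and the list-of-tuples representation are assumptions made for illustration, whereas the library function works on a JAX array of shape (n, 2).

```python
def leaky_beta_update_sketch(estimate, choices, outcome, tau_p, tau_n, decay,
                             update=1, increment=1):
    """Leaky beta update for a list of (alpha, beta) pairs, one pair per option."""
    outcome = 1 if outcome > 0 else 0  # binarise the outcome
    new_estimate = []
    for (a, b), chosen in zip(estimate, choices):
        # Decay towards the baseline of 1, then add evidence for the chosen option only
        a_upd = 1 + decay * (a - 1) + tau_p * chosen * outcome * increment
        b_upd = 1 + decay * (b - 1) + tau_n * chosen * (1 - outcome) * increment
        # update == 0 leaves the previous values untouched
        new_estimate.append((update * a_upd + (1 - update) * a,
                             update * b_upd + (1 - update) * b))
    return new_estimate


estimate = [(3.0, 2.0), (2.0, 5.0)]  # two options
choices = [1, 0]                     # option 0 was chosen (one-hot)
params = dict(outcome=1, tau_p=1.0, tau_n=1.0, decay=0.5)

updated = leaky_beta_update_sketch(estimate, choices, **params)
decay_only = leaky_beta_update_sketch(estimate, choices, increment=0, **params)
frozen = leaky_beta_update_sketch(estimate, choices, update=0, **params)

print(updated)     # [(3.0, 1.5), (1.5, 3.0)]
print(decay_only)  # [(2.0, 1.5), (1.5, 3.0)]
print(frozen)      # [(3.0, 2.0), (2.0, 5.0)]
```

In the first call both options decay halfway towards 1 and only the chosen option's alpha gains `tau_p`; with `increment=0` the decay still happens but no evidence is added; with `update=0` the estimate is returned unchanged.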

Source code in behavioural_modelling/learning/beta_models.py
@jax.jit
def leaky_beta_update(
    estimate: ArrayLike,
    choices: ArrayLike,
    outcome: float,
    tau_p: float,
    tau_n: float,
    decay: float,
    update: int = 1,
    increment: int = 1,
) -> jnp.ndarray:
    """
    Update estimates using the (asymmetric) leaky beta model. 

    This model represents the probability of the outcome associated with each option (e.g., bandits in a bandit task)
    as a beta distribution.

    Values are updated according to the following equations:

    $$
    A_i^{t+1} = \\lambda \\cdot A_i^{t} + outcome_t \\cdot \\tau^{+}
    $$

    $$
    B_i^{t+1} = \\lambda \\cdot B_i^{t} + (1-outcome_t) \\cdot \\tau^{-}
    $$

    This function also allows for updating to be turned off (i.e., the estimate is not updated at all) and for incrementing
    to be turned off (i.e., decay is applied, but the outcome is not registered).

    Only chosen options are incremented, but all options decay.

    Args:
        estimate (ArrayLike): Alpha and beta estimates for this trial. Should be an array of shape (n, 2) where
        the first dimension represents the options and the second dimension represents the alpha and beta
        parameters of each option's beta distribution.
        choices (ArrayLike): Choices made in this trial. Should have as many entries as there are options, with
        zeros for non-chosen options and ones for chosen options (i.e., one-hot encoded).
        outcome (float): Observed outcome for this trial.
        tau_p (float): Update rate for outcomes equal to 1.
        tau_n (float): Update rate for outcomes equal to 0.
        decay (float): Decay rate.
        update (int, optional): Whether to update the estimate. If 0, the estimate is not updated (i.e., no decay is
        applied, and the outcome of the trial does not affect the estimate). Defaults to 1.
        increment (int, optional): Whether to increment the estimate. If 0, the estimate is not incremented but
        decay is applied. Defaults to 1.

    Returns:
        jnp.ndarray: Updated value estimates for this trial, with one entry per bandit.
    """

    # For each parameter, we apply the decay to (previous value - 1) so that we are in effect
    # treating 1 as the baseline value. This is helpful because values < 1 can produce
    # strange-looking distributions (e.g., with joint peaks at 0 and 1). Keeping values
    # >= 1 ensures that the baseline distribution (ignoring any evidence we've observed)
    # is a flat distribution between 0 and 1. This also generally aids parameter recovery.

    # Binarise the outcome: any outcome > 0 is treated as 1, anything else as 0
    outcome = jnp.array(outcome > 0, int)

    # Update alpha
    update_1 = (
        1 + (decay * (estimate[:, 0] - 1)) + (tau_p * (choices * outcome) * increment)
    )
    estimate = estimate.at[:, 0].set(
        (update * update_1) + ((1 - update) * estimate[:, 0])
    )

    # Update beta
    update_2 = (
        1
        + (decay * (estimate[:, 1] - 1))
        + (tau_n * (choices * (1 - outcome)) * increment)
    )
    estimate = estimate.at[:, 1].set(
        (update * update_2) + ((1 - update) * estimate[:, 1])
    )

    return estimate

multiply_beta_by_scalar

multiply_beta_by_scalar(beta_params: ArrayLike, scalar: float) -> ndarray

Multiply a beta distribution by a scalar.

Parameters:

  • beta_params

    (ArrayLike) –

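    Parameters of beta distribution. Of shape (n_options, 2), where the first dimension represents the number of options (each of which has its own beta distribution), and the second dimension represents the alpha and beta parameters of each beta distribution.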

  • scalar

    (float) –

    Scalar to multiply beta distribution by.

Returns:

  • ndarray

    jnp.ndarray: New beta distribution parameters, specified as [a, b].
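
The scaling is done by moment matching: the mean is multiplied by the scalar, the variance by the scalar squared, and new parameters are chosen to reproduce those moments. The plain-Python sketch below follows the same steps for a single distribution; the helper name `multiply_beta_by_scalar_sketch` is hypothetical.

```python
import math


def multiply_beta_by_scalar_sketch(a, b, scalar):
    """Moment-matched (alpha, beta) for Beta(a, b) scaled by `scalar` (illustrative helper)."""
    # Moments of the original distribution
    mean = a / (a + b)
    var = (a * b) / ((a + b) ** 2 * (a + b + 1))
    # Scale the moments
    mean, var = mean * scalar, var * scalar**2
    # Recover parameters with the scaled mean and variance
    common = (mean * (1 - mean)) / var - 1
    return mean * common, (1 - mean) * common


# Halving Beta(2, 2): mean 0.5 -> 0.25, variance 0.05 -> 0.0125
a_new, b_new = multiply_beta_by_scalar_sketch(2.0, 2.0, 0.5)

# The new parameters reproduce the scaled mean
new_mean = a_new / (a_new + b_new)

print(a_new, b_new, new_mean)
```

The result is only a valid beta distribution when the scaled mean stays strictly between 0 and 1 and the scaled variance is smaller than mean × (1 − mean); otherwise one or both parameters come out non-positive.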

Source code in behavioural_modelling/learning/beta_models.py
@jax.jit
def multiply_beta_by_scalar(beta_params: ArrayLike, scalar: float) -> jnp.ndarray:
    """
    Multiply a beta distribution by a scalar.

    Args:
        beta_params (ArrayLike): Parameters of beta distribution. Of shape (n_options, 2),
        where the first dimension represents the number of options (each of which has its own
        beta distribution), and the second dimension represents the alpha and beta parameters
        of each beta distribution.
        scalar (float): Scalar to multiply beta distribution by.

    Returns:
        jnp.ndarray: New beta distribution parameters, specified as [a, b].
    """

    # Calculate mean and variance (beta_mean_var extracts alpha and beta itself)
    mean, var = beta_mean_var(beta_params)

    # Scale mean and variance
    mean = mean * scalar
    var = var * scalar**2

    # Calculate new parameters
    a_new = mean * ((mean * (1 - mean)) / var - 1)
    b_new = (1 - mean) * ((mean * (1 - mean)) / var - 1)

    # Return new parameters
    return jnp.stack([a_new, b_new], axis=-1)

sum_betas

sum_betas(beta1_params: ArrayLike, beta2_params: ArrayLike) -> ndarray

Sum two beta distributions. This uses an approximation described in the following paper:

Pham, T.G., Turkkan, N., 1994. Reliability of a standby system with beta-distributed component lives. IEEE Transactions on Reliability 43, 71–75. https://doi.org/10.1109/24.285114

Where the first two moments of the summed distribution are calculated as follows:

\[ \mu = \mu_1 + \mu_2 \]
\[ \sigma^2 = \sigma_1^2 + \sigma_2^2 \]

We then calculate the parameters of the new beta distribution using the following equations:

\[ \alpha = \mu \left( \frac{\mu (1 - \mu)}{\sigma^2} - 1 \right) \]
\[ \beta = (1 - \mu) \left( \frac{\mu (1 - \mu)}{\sigma^2} - 1 \right) \]

This function assumes that the means of the two beta distributions sum to <=1. If this is not the case, the output will be invalid.
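
For reference, the parameter equations above follow from inverting the standard beta moments. Starting from the definitions of the mean and variance:

\[ \mu = \frac{\alpha}{\alpha + \beta}, \qquad \sigma^2 = \frac{\alpha \beta}{(\alpha + \beta)^2 (\alpha + \beta + 1)} = \frac{\mu (1 - \mu)}{\alpha + \beta + 1} \]

Solving the variance expression for the total \( \alpha + \beta \) and substituting back gives:

\[ \alpha + \beta = \frac{\mu (1 - \mu)}{\sigma^2} - 1, \qquad \alpha = \mu (\alpha + \beta), \qquad \beta = (1 - \mu)(\alpha + \beta) \]

If \( \mu > 1 \), then \( \mu (1 - \mu) < 0 \), so \( \alpha + \beta < 0 \) and the parameters cannot both be positive, which is why the output is invalid in that case.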

Parameters:

  • beta1_params

    (ArrayLike) –

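    Parameters of the first beta distribution. Of shape (n_options, 2), where the first dimension represents the number of options (each of which has its own beta distribution), and the second dimension represents the alpha and beta parameters of each beta distribution.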

  • beta2_params

    (ArrayLike) –

    Parameters of second beta distribution.

Returns:

  • ndarray

    jnp.ndarray: New beta distribution parameters.
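
The following plain-Python sketch walks through the approximation for one pair of distributions and also shows what happens when the means sum to more than 1. The helper name `sum_betas_sketch` is an assumption for this example; it mirrors the steps in the source listing rather than calling the library.

```python
import math


def sum_betas_sketch(beta1, beta2):
    """Moment-matched (alpha, beta) approximating the sum of two beta distributions."""
    (a1, b1), (a2, b2) = beta1, beta2
    # Moments of each distribution
    mean1 = a1 / (a1 + b1)
    var1 = (a1 * b1) / ((a1 + b1) ** 2 * (a1 + b1 + 1))
    mean2 = a2 / (a2 + b2)
    var2 = (a2 * b2) / ((a2 + b2) ** 2 * (a2 + b2 + 1))
    # Means and variances add
    mean_new = mean1 + mean2
    var_new = var1 + var2
    # Recover parameters from the summed moments
    common = (mean_new * (1 - mean_new)) / var_new - 1
    return mean_new * common, (1 - mean_new) * common


# Means 0.2 and 0.3 sum to 0.5, so the result is a valid, symmetric beta distribution
a_ok, b_ok = sum_betas_sketch((2.0, 8.0), (3.0, 7.0))

# Means 0.8 and 0.7 sum to 1.5, which violates the assumption: alpha comes out negative
a_bad, b_bad = sum_betas_sketch((8.0, 2.0), (7.0, 3.0))

print(a_ok, b_ok)
print(a_bad < 0)  # True
```

Because the function does not check its inputs, callers may want to verify that the summed mean is at most 1 before using the result.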

Source code in behavioural_modelling/learning/beta_models.py
@jax.jit
def sum_betas(beta1_params: ArrayLike, beta2_params: ArrayLike) -> jnp.ndarray:
    """
    Sum two beta distributions. This uses an approximation described in the following paper:

    Pham, T.G., Turkkan, N., 1994. Reliability of a standby system with beta-distributed component lives.
    IEEE Transactions on Reliability 43, 71–75. https://doi.org/10.1109/24.285114

    Where the first two moments of the summed distribution are calculated as follows:

    $$
    \\mu = \\mu_1 + \\mu_2
    $$

    $$
    \\sigma^2 = \\sigma_1^2 + \\sigma_2^2
    $$

    We then calculate the parameters of the new beta distribution using the following equations:

    $$
    \\alpha = \\mu \\left( \\frac{\\mu (1 - \\mu)}{\\sigma^2} - 1 \\right)
    $$

    $$
    \\beta = (1 - \\mu) \\left( \\frac{\\mu (1 - \\mu)}{\\sigma^2} - 1 \\right)
    $$

    This function assumes that the means of the two beta distributions sum to <=1. If this is not the case,
    the output will be invalid.

    Args:
        beta1_params (ArrayLike): Parameters of the first beta distribution. Of shape (n_options, 2),
        where the first dimension represents the number of options (each of which has its own
        beta distribution), and the second dimension represents the alpha and beta parameters
        of each beta distribution.
        beta2_params (ArrayLike): Parameters of second beta distribution.

    Returns:
        jnp.ndarray: New beta distribution parameters.
    """

    # Extract parameters
    a1 = beta1_params[..., 0]
    b1 = beta1_params[..., 1]
    a2 = beta2_params[..., 0]
    b2 = beta2_params[..., 1]

    # Calculate means and variances
    mean1 = a1 / (a1 + b1)
    var1 = (a1 * b1) / ((a1 + b1) ** 2 * (a1 + b1 + 1))
    mean2 = a2 / (a2 + b2)
    var2 = (a2 * b2) / ((a2 + b2) ** 2 * (a2 + b2 + 1))

    # Sum means and variances
    mean_new = mean1 + mean2
    var_new = var1 + var2

    # Calculate new parameters
    a_new = mean_new * ((mean_new * (1 - mean_new)) / var_new - 1)
    b_new = (1 - mean_new) * ((mean_new * (1 - mean_new)) / var_new - 1)

    # Return new parameters
    return jnp.stack([a_new, b_new], axis=-1)