Rescorla-Wagner Models

This module includes various implementations of the Rescorla-Wagner model.

rescorla_wagner

asymmetric_rescorla_wagner_update module-attribute

asymmetric_rescorla_wagner_update: Tuple[ArrayLike, ArrayLike] = jit(asymmetric_rescorla_wagner_update, static_argnames='counterfactual_value')

Updates the estimated value of a state or action using the Asymmetric Rescorla-Wagner learning rule.

The function calculates the prediction error as the difference between the actual outcome and the current estimated value. It then updates the estimated value based on the prediction error and the learning rate, which is determined by whether the prediction error is positive or negative.

Value estimates are only updated for chosen actions. For unchosen actions, the prediction error is set to 0.

Counterfactual updating can be used to set the value of unchosen actions according to a function of the value of chosen actions. This can be useful in cases where the value of unchosen actions should be set to a specific value, such as the negative of the value of chosen actions. By default this function sets the value of unchosen actions to the complement of the value of chosen actions:

counterfactual_value: callable = lambda reward, chosen: jnp.where(
    chosen == 1, 
    0.0, 
    1.0 - jnp.sum(reward * jnp.asarray(chosen == 1, dtype=reward.dtype))
)
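
To make the learning rule concrete, the sketch below shows the core asymmetric update for a single trial. This is an illustrative standalone snippet rather than the library function itself; the argument names simply follow the parameter list below.

import jax.numpy as jnp

def asymmetric_update_sketch(value, outcome, chosen, alpha_p, alpha_n):
    # Prediction error, zeroed out for unchosen actions
    prediction_error = (outcome - value) * chosen
    # Use alpha_p for positive errors and alpha_n for negative errors
    alpha = jnp.where(prediction_error > 0, alpha_p, alpha_n)
    return value + alpha * prediction_error, prediction_error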

Parameters:

  • value

    (ArrayLike) –

    The current estimated value of a state or action.

  • alpha_p

    (ArrayLike) –

    The learning rate used when the prediction error is positive.

  • alpha_n

    (ArrayLike) –

    The learning rate used when the prediction error is negative.

  • counterfactual_value

    (callable) –

    The value to use for unchosen actions. This should be provided as a callable function that returns a value. This will have no effect if update_all_options is set to False. The function takes as input the values of outcome and chosen (i.e., the two elements of the outcome_chosen argument). By default, this assumes that outcomes are binary and sets the value of unchosen actions to the complement of the value of chosen actions.

  • update_all_options

    (bool) –

    Whether to update the value estimates for all options, regardless of whether they were chosen. Defaults to False.

Returns:

  • Tuple[ArrayLike, ArrayLike]

    Tuple[jax.typing.ArrayLike, jax.typing.ArrayLike]: The updated value and the prediction error.

asymmetric_rescorla_wagner_update_choice module-attribute

asymmetric_rescorla_wagner_update_choice: Array = jit(asymmetric_rescorla_wagner_update_choice, static_argnums=(5, 6))

Updates the value estimate using the asymmetric Rescorla-Wagner algorithm, and chooses an option based on the softmax function.

See asymmetric_rescorla_wagner_update for details on the learning rule.
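
A minimal sketch of the choice step described above, assuming a standard softmax policy over the value estimates; whether the temperature multiplies or divides the values here is an assumption, so see the library source for the exact form.

import jax
import jax.numpy as jnp

def softmax_choice_sketch(value, key, temperature, n_actions):
    # Convert value estimates into choice probabilities
    choice_p = jax.nn.softmax(value * temperature)
    # Sample an action using the supplied PRNG key
    choice = jax.random.categorical(key, jnp.log(choice_p))
    # One-hot encode the chosen action
    choice_array = jax.nn.one_hot(choice, n_actions)
    return choice_p, choice, choice_array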

Parameters:

  • value

    (ArrayLike) –

    The current value estimate.

  • outcome_key

    (Tuple[ArrayLike, PRNGKey]) –

    A tuple containing the outcome and the PRNG key.

  • alpha_p

    (float) –

    The learning rate for positive outcomes.

  • alpha_n

    (float) –

    The learning rate for negative outcomes.

  • temperature

    (float) –

    The temperature parameter for the softmax function.

  • n_actions

    (int) –

    The number of actions to choose from.

  • counterfactual_value

    (callable) –

    The value to use for unchosen actions. This should be provided as a callable function that returns a value. This will have no effect if update_all_options is set to False. The function takes as input the values of outcome and chosen (i.e., the two elements of the outcome_chosen argument). By default, this assumes that outcomes are binary and sets the value of unchosen actions to the complement of the value of chosen actions.

  • update_all_options

    (bool) –

    Whether to update the value estimates for all options, regardless of whether they were chosen. Defaults to False.

Returns:

  • Array

    Tuple[jax.Array, Tuple[jax.Array, jax.Array, int, jax.Array]]:

    - updated_value (jax.Array): The updated value estimate.
    - output_tuple (Tuple[jax.Array, jax.Array, int, jax.Array]):
        - value (jax.Array): The original value estimate.
        - choice_p (jax.Array): The choice probabilities.
        - choice (int): The chosen action.
        - choice_array (jax.Array): The chosen action in one-hot format.

asymmetric_rescorla_wagner_update_choice_sticky module-attribute

asymmetric_rescorla_wagner_update_choice_sticky: Array = jit(asymmetric_rescorla_wagner_update_choice_sticky, static_argnums=(6, 7))

Updates the value estimate using the asymmetric Rescorla-Wagner algorithm, and chooses an option based on the softmax function.

Incorporates an additional choice stickiness parameter, such that the probability of choosing the same option as on the previous trial is increased (or decreased if the stickiness parameter is negative).

See asymmetric_rescorla_wagner_update for details on the learning rule.
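
The stickiness term can be pictured as a bonus added to the softmax input for the previously chosen option, roughly as in the hedged sketch below; exactly how it combines with the temperature is an assumption, so see the library source for the exact form.

import jax
import jax.numpy as jnp

def sticky_softmax_sketch(value, previous_choice, temperature, stickiness):
    # previous_choice is one-hot; stickiness shifts the logit of the repeated option
    logits = value * temperature + stickiness * previous_choice
    return jax.nn.softmax(logits)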

Parameters:

  • value_choice

    (Tuple[ArrayLike, ArrayLike]) –

    A tuple containing the current value estimate and the previous choice. The previous choice should be a one-hot encoded array of shape (n_actions,) where 1 indicates the chosen action and 0 indicates the unchosen actions.

  • outcome_key

    (Tuple[ArrayLike, PRNGKey]) –

    A tuple containing the outcome and the PRNG key.

  • alpha_p

    (float) –

    The learning rate for positive outcomes.

  • alpha_n

    (float) –

    The learning rate for negative outcomes.

  • temperature

    (float) –

    The temperature parameter for the softmax function.

  • stickiness

    (float) –

    The stickiness parameter for the softmax function.

  • n_actions

    (int) –

    The number of actions to choose from.

  • counterfactual_value

    (callable) –

    The value to use for unchosen actions. This should be provided as a callable function that returns a value. This will have no effect if update_all_options is set to False. The function takes as input the values of outcome and chosen (i.e., the two elements of the outcome_chosen argument). Defaults to lambda x, y: (1 - x) * (1 - y), which assumes outcomes are binary (0 or 1), and sets the value of unchosen actions to complement the value of chosen actions (i.e., a chosen value of 1 will set the unchosen value to 0 and vice versa).

  • update_all_options

    (bool) –

    Whether to update the value estimates for all options, regardless of whether they were chosen. Defaults to False.

Returns:

  • Array

    Tuple[jax.Array, Tuple[jax.Array, jax.Array, int, jax.Array]]:

    - updated_value (jax.Array): The updated value estimate.
    - output_tuple (Tuple[jax.Array, jax.Array, int, jax.Array]):
        - value (jax.Array): The original value estimate.
        - choice_p (jax.Array): The choice probabilities.
        - choice (int): The chosen action.
        - choice_array (jax.Array): The chosen action in one-hot format.

asymmetric_volatile_dynamic_rescorla_wagner_update_choice module-attribute

asymmetric_volatile_dynamic_rescorla_wagner_update_choice: Array = jit(asymmetric_volatile_dynamic_rescorla_wagner_update_choice, static_argnums=(7,))

Updates the value estimate using a variant of the Rescorla-Wagner learning rule that adjusts the learning rate based on volatility and the sign of the prediction error, and chooses an option based on the softmax function.

Note that learning rates for this function are transformed using a sigmoid function to ensure they are between 0 and 1. The raw parameter values supplied to the function must therefore be unbounded.
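
The dynamic learning rate combines the four alpha parameters through a sigmoid, as in the asymmetric_volatile_rescorla_wagner_update source shown later on this page; the illustrative computation below uses hypothetical parameter values.

import jax

# alpha_t = sigmoid(alpha_base + alpha_volatility * volatile
#                   + alpha_pos_neg * sign(PE) + alpha_interaction * volatile * sign(PE))
alpha_base, alpha_volatility, alpha_pos_neg, alpha_interaction = 0.0, 0.5, -0.3, 0.1
volatile, pe_sign = 1.0, 1.0  # a volatile trial with a positive prediction error
alpha_t = jax.nn.sigmoid(
    alpha_base
    + alpha_volatility * volatile
    + alpha_pos_neg * pe_sign
    + alpha_interaction * volatile * pe_sign
)  # sigmoid(0.3) ≈ 0.57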

Parameters:

  • value

    (ArrayLike) –

    The current value estimate.

  • alpha_base

    (float) –

    The base learning rate.

  • alpha_volatility

    (float) –

    The learning rate adjustment for volatile outcomes.

  • alpha_pos_neg

    (float) –

    The learning rate adjustment for positive and negative prediction errors.

  • alpha_interaction

    (float) –

    The learning rate adjustment for the interaction between volatility and prediction error sign.

  • temperature

    (float) –

    The temperature parameter for the softmax function.

  • n_actions

    (int) –

    The number of actions to choose from.

Returns:

  • Array

    Tuple[jax.Array, Tuple[jax.typing.ArrayLike, jax.Array, int, jax.Array]]:

    - updated_value (jax.Array): The updated value estimate.
    - output_tuple (Tuple[jax.Array, jax.Array, int, jax.Array]):
        - value (jax.Array): The original value estimate.
        - choice_p (jax.Array): The choice probabilities.
        - choice (int): The chosen action.
        - choice_array (jax.Array): The chosen action in one-hot format.

asymmetric_volatile_rescorla_wagner_single_value_update_choice module-attribute

asymmetric_volatile_rescorla_wagner_single_value_update_choice: Array = jit(asymmetric_volatile_rescorla_wagner_single_value_update_choice)

Updates the value estimate using the asymmetric volatile dynamic Rescorla-Wagner algorithm, and chooses an option based on the softmax function.

This version of the function is designed for cases where a single value is being learnt, and this value is used to determine which of two options to choose. In practice, the value of option 1 is learnt, and the value of option 2 is set to 1 - value. This is appropriate when the value of one option is the complement of the other.

Note that learning rates for this function are transformed using a sigmoid function to ensure they are between 0 and 1. The raw parameter values supplied to the function must therefore be unbounded.
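
As a concrete illustration of the single-value scheme described above (a hedged sketch, not the library code):

import jax.numpy as jnp

value = 0.7  # learnt value of option 1
option_values = jnp.array([value, 1.0 - value])  # option 2 is the complement of option 1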

Parameters:

  • value

    (ArrayLike) –

    The current value estimate.

  • alpha_base

    (float) –

    The base learning rate.

  • alpha_volatility

    (float) –

    The learning rate adjustment for volatile outcomes.

  • alpha_pos_neg

    (float) –

    The learning rate adjustment for positive and negative prediction errors.

  • alpha_interaction

    (float) –

    The learning rate adjustment for the interaction between volatility and prediction error sign.

  • temperature

    (float) –

    The temperature parameter for the softmax function.

  • n_actions

    (int) –

    The number of actions to choose from.

asymmetric_volatile_rescorla_wagner_update

asymmetric_volatile_rescorla_wagner_update(value: ArrayLike, outcome_chosen_volatility: Tuple[ArrayLike, ArrayLike, ArrayLike], alpha_base: float, alpha_volatility: float, alpha_pos_neg: float, alpha_interaction: float) -> Tuple[ArrayLike, Tuple[ArrayLike, ArrayLike]]

Updates the estimated value of a state or action using a variant of the Rescorla-Wagner learning rule that adjusts the learning rate based on both volatility and the sign of the prediction error.

Note that learning rates for this function are transformed using a sigmoid function to ensure they are between 0 and 1. The raw parameter values supplied to the function must therefore be unbounded.

Parameters:

  • value

    (ArrayLike) –

    The current estimated value of a state or action.

  • alpha_base

    (float) –

    The base learning rate.

  • alpha_volatility

    (float) –

    The learning rate adjustment for volatile outcomes.

  • alpha_pos_neg

    (float) –

    The learning rate adjustment for positive and negative prediction errors.

  • alpha_interaction

    (float) –

    The learning rate adjustment for the interaction between volatility and prediction error sign.

Returns:

  • Tuple[ArrayLike, Tuple[ArrayLike, ArrayLike]]

    Tuple[jax.typing.ArrayLike, Tuple[jax.typing.ArrayLike, jax.typing.ArrayLike]]:

    - updated_value (jax.typing.ArrayLike): The updated value estimate.
    - output_tuple (Tuple[jax.typing.ArrayLike, jax.typing.ArrayLike]):
        - value (jax.typing.ArrayLike): The original value estimate.
        - prediction_error (jax.typing.ArrayLike): The prediction error.

Source code in behavioural_modelling/learning/rescorla_wagner.py
@jax.jit
def asymmetric_volatile_rescorla_wagner_update(
    value: jax.typing.ArrayLike,
    outcome_chosen_volatility: Tuple[
        jax.typing.ArrayLike,
        jax.typing.ArrayLike,
        jax.typing.ArrayLike,
    ],
    alpha_base: float,
    alpha_volatility: float,
    alpha_pos_neg: float,
    alpha_interaction: float,
) -> Tuple[
    jax.typing.ArrayLike, Tuple[jax.typing.ArrayLike, jax.typing.ArrayLike]
]:
    """
    Updates the estimated value of a state or action using a variant
    of the Rescorla-Wagner learning rule that incorporates adjusting
    the learning rate based on both volatility and prediction error sign.

    Note that learning rates for this function are transformed using a
    sigmoid function to ensure they are between 0 and 1. The raw
    parameter values supplied to the function must therefore be
    unbounded.

    Args:
        value (jax.typing.ArrayLike): The current estimated value of a
            state or action.
        outcome_chosen_volatility (Tuple[jax.typing.ArrayLike, jax.typing.ArrayLike,
            jax.typing.ArrayLike]): A tuple containing the outcome, the chosen
            action, and the volatility indicator. The outcome is a float or an
            array (e.g., for a single outcome or multiple outcomes). The chosen
            action is a one-hot encoded array of shape (n_actions,) where 1
            indicates the chosen action and 0 indicates the unchosen actions.
            The volatility indicator is a binary value that indicates whether
            the outcome is volatile (1) or stable (0).
        alpha_base (float): The base learning rate.
        alpha_volatility (float): The learning rate adjustment for volatile
            outcomes.
        alpha_pos_neg (float): The learning rate adjustment for positive and
            negative prediction errors.
        alpha_interaction (float): The learning rate adjustment for the
            interaction between volatility and prediction error sign.

    Returns:
        Tuple[jax.typing.ArrayLike, Tuple[jax.typing.ArrayLike,
            jax.typing.ArrayLike]]:
            - updated_value (jax.typing.ArrayLike): The updated value estimate.
            - output_tuple (Tuple[jax.typing.ArrayLike, jax.typing.ArrayLike]):
                - value (jax.typing.ArrayLike): The original value estimate.
                - prediction_error (jax.typing.ArrayLike): The prediction
                  error.
    """

    # Unpack the outcome and the chosen action
    outcome, chosen, volatility_indicator = outcome_chosen_volatility

    # Calculate the prediction error
    prediction_error = outcome - value

    # Set prediction error to 0 for unchosen actions
    prediction_error = prediction_error * chosen

    # Determine whether the error is positive (1) or negative (-1)
    PE_sign = jnp.sign(prediction_error)

    # Compute interaction term (volatility_indicator * error_sign)
    interaction_term = volatility_indicator * PE_sign

    # Compute the dynamic learning rate using base, volatility, and interaction terms
    # Remember we can't use if else statements here because JAX doesn't tolerate them
    # Use adjusted learning rates for positive/negative prediction errors
    alpha_t = jax.nn.sigmoid(
        alpha_base
        + alpha_volatility * volatility_indicator
        + alpha_pos_neg * PE_sign
        + alpha_interaction * interaction_term
    )

    # Update the value
    updated_value = value + alpha_t * prediction_error

    return updated_value, (value, prediction_error)
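
A hedged usage example follows, assuming the module is importable as behavioural_modelling.learning.rescorla_wagner and using illustrative parameter values.

import jax.numpy as jnp
from behavioural_modelling.learning.rescorla_wagner import (
    asymmetric_volatile_rescorla_wagner_update,
)

value = jnp.array([0.5, 0.5])    # current value estimates for two actions
outcome = jnp.array([1.0, 0.0])  # observed outcome
chosen = jnp.array([1.0, 0.0])   # action 0 was chosen (one-hot)
volatility = 1.0                 # volatile trial
updated_value, (old_value, prediction_error) = asymmetric_volatile_rescorla_wagner_update(
    value, (outcome, chosen, volatility), 0.0, 0.5, -0.3, 0.1
)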

complement_counterfactual

complement_counterfactual(reward, chosen)

Counterfactual function that sets the value of unchosen actions to the complement of the value of chosen actions.

Parameters:

  • reward

    (ArrayLike) –

    The reward received.

  • chosen

    (ArrayLike) –

    A binary array indicating which action(s) were chosen.

Returns:

  • ArrayLike

    jax.typing.ArrayLike: The counterfactual value.

Source code in behavioural_modelling/learning/rescorla_wagner.py
def complement_counterfactual(reward, chosen):
    """
    Counterfactual function that sets the value of unchosen actions to the
    complement of the value of chosen actions.

    Args:
        reward (jax.typing.ArrayLike): The reward received.
        chosen (jax.typing.ArrayLike): A binary array indicating which
            action(s) were chosen.

    Returns:
        jax.typing.ArrayLike: The counterfactual value.
    """
    return jnp.where(
        chosen == 1, 
        0.0, 
        1.0 - jnp.sum(reward * jnp.asarray(chosen == 1, dtype=reward.dtype))
    )
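
For example (illustrative values, with the function imported from the same module), if the chosen action yields a reward of 0, the unchosen action is assigned the complement value 1:

import jax.numpy as jnp
from behavioural_modelling.learning.rescorla_wagner import complement_counterfactual

reward = jnp.array([0.0, 0.0])   # the chosen action returned 0
chosen = jnp.array([1.0, 0.0])   # action 0 was chosen (one-hot)
complement_counterfactual(reward, chosen)  # -> Array([0., 1.], dtype=float32)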