Rescorla-Wagner Models
This module includes various implementations of the Rescorla-Wagner model.
rescorla_wagner
Functions:
- asymmetric_volatile_rescorla_wagner_update – Updates the estimated value of a state or action using a variant of the Rescorla-Wagner learning rule that adjusts the learning rate based on both volatility and prediction error sign.
- complement_counterfactual – Counterfactual function that sets the value of unchosen actions to the complement of the value of chosen actions.

Attributes:
- asymmetric_rescorla_wagner_update (Tuple[ArrayLike, ArrayLike]) – Updates the estimated value of a state or action using the Asymmetric Rescorla-Wagner learning rule.
- asymmetric_rescorla_wagner_update_choice (Array) – Updates the value estimate using the asymmetric Rescorla-Wagner algorithm, and chooses an option based on the softmax function.
- asymmetric_rescorla_wagner_update_choice_sticky (Array) – Updates the value estimate using the asymmetric Rescorla-Wagner algorithm with choice stickiness, and chooses an option based on the softmax function.
- asymmetric_volatile_dynamic_rescorla_wagner_update_choice (Array) – Updates the value estimate using a variant of the Rescorla-Wagner learning rule that adjusts the learning rate based on volatility and prediction error sign, and chooses an option based on the softmax function.
- asymmetric_volatile_rescorla_wagner_single_value_update_choice (Array) – Updates the value estimate using the asymmetric volatile dynamic Rescorla-Wagner algorithm, and chooses an option based on the softmax function.
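As a usage hint (this is an assumption based on the return structures documented below, not an official example from the module), each jitted update function returns an (updated_value, outputs) pair, which matches the carry/output convention of jax.lax.scan. A run of trials could therefore be simulated roughly as follows, with learning rates and other fixed parameters bound in advance (e.g. via functools.partial):

```python
import jax

# Hypothetical usage pattern: scan an update function (signature assumed to be
# update_fn(value, (outcome, key)) -> (updated_value, outputs)) over a sequence of trials.
def simulate(update_fn, initial_value, outcomes, key):
    keys = jax.random.split(key, outcomes.shape[0])  # one PRNG key per trial
    final_value, history = jax.lax.scan(update_fn, initial_value, (outcomes, keys))
    return final_value, history
```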
asymmetric_rescorla_wagner_update
module-attribute
asymmetric_rescorla_wagner_update: Tuple[ArrayLike, ArrayLike] = jit(asymmetric_rescorla_wagner_update, static_argnames='counterfactual_value')
Updates the estimated value of a state or action using the Asymmetric Rescorla-Wagner learning rule.
The function calculates the prediction error as the difference between the actual outcome and the current estimated value. It then updates the estimated value based on the prediction error and the learning rate, which is determined by whether the prediction error is positive or negative.
Value estimates are only updated for chosen actions. For unchosen actions, the prediction error is set to 0.
Counterfactual updating can be used to set the value of unchosen actions according to a function of the value of chosen actions. This can be useful in cases where the value of unchosen actions should be set to a specific value, such as the negative of the value of chosen actions. By default this function sets the value of unchosen actions to the complement of the value of chosen actions:
counterfactual_value: callable = lambda reward, chosen: jnp.where(
    chosen == 1,
    0.0,
    1.0 - jnp.sum(reward * jnp.asarray(chosen == 1, dtype=reward.dtype))
)
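To make the rule concrete, a minimal sketch of the asymmetric update for chosen actions might look like the following (an illustrative approximation, not the module's actual source; counterfactual updating of unchosen actions is omitted):

```python
import jax.numpy as jnp

def asymmetric_rw_sketch(value, outcome, chosen, alpha_p, alpha_n):
    """Illustrative asymmetric Rescorla-Wagner update for chosen actions only."""
    # Prediction error is zeroed out for unchosen actions
    prediction_error = (outcome - value) * chosen
    # Positive prediction errors use alpha_p, negative ones use alpha_n
    alpha = jnp.where(prediction_error > 0, alpha_p, alpha_n)
    updated_value = value + alpha * prediction_error
    return updated_value, prediction_error
```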
Parameters:
- value (ArrayLike) – The current estimated value of a state or action.
- alpha_p (ArrayLike) – The learning rate used when the prediction error is positive.
- alpha_n (ArrayLike) – The learning rate used when the prediction error is negative.
- counterfactual_value (callable) – The value to use for unchosen actions. This should be provided as a callable function that returns a value. It has no effect if update_all_options is set to False. The function takes as input the values of outcome and chosen (i.e., the two elements of the outcome_chosen argument). By default, this assumes that outcomes are binary and sets the value of unchosen actions to the complement of the value of chosen actions.
- update_all_options (bool) – Whether to update the value estimates for all options, regardless of whether they were chosen. Defaults to False.

Returns:
- Tuple[ArrayLike, ArrayLike] – The updated value and the prediction error.
asymmetric_rescorla_wagner_update_choice
module-attribute
asymmetric_rescorla_wagner_update_choice: Array = jit(asymmetric_rescorla_wagner_update_choice, static_argnums=(5, 6))
Updates the value estimate using the asymmetric Rescorla-Wagner algorithm, and chooses an option based on the softmax function.
See asymmetric_rescorla_wagner_update for details on the learning rule.
Parameters:
- value (ArrayLike) – The current value estimate.
- outcome_key (Tuple[ArrayLike, PRNGKey]) – A tuple containing the outcome and the PRNG key.
- alpha_p (float) – The learning rate for positive outcomes.
- alpha_n (float) – The learning rate for negative outcomes.
- temperature (float) – The temperature parameter for the softmax function.
- n_actions (int) – The number of actions to choose from.
- counterfactual_value (callable) – The value to use for unchosen actions. This should be provided as a callable function that returns a value. It has no effect if update_all_options is set to False. The function takes as input the values of outcome and chosen (i.e., the two elements of the outcome_chosen argument). By default, this assumes that outcomes are binary and sets the value of unchosen actions to the complement of the value of chosen actions.
- update_all_options (bool) – Whether to update the value estimates for all options, regardless of whether they were chosen.

Returns:
- Tuple[jax.Array, Tuple[jax.Array, jax.Array, int, jax.Array]]:
    - updated_value (jax.Array): The updated value estimate.
    - output_tuple (Tuple[jax.Array, jax.Array, int, jax.Array]):
        - value (jax.Array): The original value estimate.
        - choice_p (jax.Array): The choice probabilities.
        - choice (int): The chosen action.
        - choice_array (jax.Array): The chosen action in one-hot format.
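The choice step described above could be sketched as follows (a rough illustration under the assumption that temperature acts as an inverse-temperature multiplier; this is not the library's implementation):

```python
import jax
import jax.numpy as jnp

def softmax_choice_sketch(value, key, temperature, n_actions):
    """Illustrative softmax action selection over current value estimates."""
    # Temperature convention (multiplicative inverse temperature) is an assumption
    choice_p = jax.nn.softmax(value * temperature)
    choice = jax.random.choice(key, jnp.arange(n_actions), p=choice_p)
    choice_array = jax.nn.one_hot(choice, n_actions)  # chosen action in one-hot format
    return choice_p, choice, choice_array
```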
asymmetric_rescorla_wagner_update_choice_sticky
module-attribute
asymmetric_rescorla_wagner_update_choice_sticky: Array = jit(asymmetric_rescorla_wagner_update_choice_sticky, static_argnums=(6, 7))
Updates the value estimate using the asymmetric Rescorla-Wagner algorithm, and chooses an option based on the softmax function.
Incorporates additional choice stickiness parameter, such that the probability of choosing the same option as the previous trial is increased (or decreased if the value is negative).
See asymmetric_rescorla_wagner_update for details on the learning rule.
Parameters:
- value_choice (Tuple[ArrayLike, ArrayLike]) – A tuple containing the current value estimate and the previous choice. The previous choice should be a one-hot encoded array of shape (n_actions,) where 1 indicates the chosen action and 0 indicates the unchosen actions.
- outcome_key (Tuple[ArrayLike, PRNGKey]) – A tuple containing the outcome and the PRNG key.
- alpha_p (float) – The learning rate for positive outcomes.
- alpha_n (float) – The learning rate for negative outcomes.
- temperature (float) – The temperature parameter for the softmax function.
- stickiness (float) – The stickiness parameter for the softmax function.
- n_actions (int) – The number of actions to choose from.
- counterfactual_value (callable) – The value to use for unchosen actions. This should be provided as a callable function that returns a value. It has no effect if update_all_options is set to False. The function takes as input the values of outcome and chosen (i.e., the two elements of the outcome_chosen argument). Defaults to lambda x, y: (1 - x) * (1 - y), which assumes outcomes are binary (0 or 1) and sets the value of unchosen actions to complement the value of chosen actions (i.e., a chosen value of 1 will set the unchosen value to 0 and vice versa).
- update_all_options (bool) – Whether to update the value estimates for all options, regardless of whether they were chosen.

Returns:
- Tuple[jax.Array, Tuple[jax.Array, jax.Array, int, jax.Array]]:
    - updated_value (jax.Array): The updated value estimate.
    - output_tuple (Tuple[jax.Array, jax.Array, int, jax.Array]):
        - value (jax.Array): The original value estimate.
        - choice_p (jax.Array): The choice probabilities.
        - choice (int): The chosen action.
        - choice_array (jax.Array): The chosen action in one-hot format.
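Building on the previous sketch, the stickiness term could enter the softmax roughly like this (how the library actually combines temperature and stickiness is an assumption):

```python
import jax
import jax.numpy as jnp

def sticky_softmax_sketch(value, previous_choice, key, temperature, stickiness, n_actions):
    """Illustrative softmax choice with a stickiness bonus for the previous action."""
    # previous_choice is one-hot, so stickiness * previous_choice biases the softmax
    # towards the last chosen option (or away from it if stickiness is negative)
    choice_p = jax.nn.softmax(value * temperature + stickiness * previous_choice)
    choice = jax.random.choice(key, jnp.arange(n_actions), p=choice_p)
    return choice_p, choice, jax.nn.one_hot(choice, n_actions)
```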
asymmetric_volatile_dynamic_rescorla_wagner_update_choice
module-attribute
asymmetric_volatile_dynamic_rescorla_wagner_update_choice: Array = jit(asymmetric_volatile_dynamic_rescorla_wagner_update_choice, static_argnums=(7,))
Updates the value estimate using a variant of the Rescorla-Wagner learning rule that adjusts learning rate based on volatility and prediction error sign, and chooses an option based on the softmax function.
Note that learning rates for this function are transformed using a sigmoid function to ensure they are between 0 and 1. The raw parameter values supplied to the function must therefore be unbounded.
Parameters:
- value (ArrayLike) – The current value estimate.
- alpha_base (float) – The base learning rate.
- alpha_volatility (float) – The learning rate adjustment for volatile outcomes.
- alpha_pos_neg (float) – The learning rate adjustment for positive and negative prediction errors.
- alpha_interaction (float) – The learning rate adjustment for the interaction between volatility and prediction error sign.
- temperature (float) – The temperature parameter for the softmax function.
- n_actions (int) – The number of actions to choose from.

Returns:
- Tuple[jax.Array, Tuple[jax.typing.ArrayLike, jax.Array, int, jax.Array]]:
    - updated_value (jax.Array): The updated value estimate.
    - output_tuple (Tuple[jax.Array, jax.Array, int, jax.Array]):
        - value (jax.Array): The original value estimate.
        - choice_p (jax.Array): The choice probabilities.
        - choice (int): The chosen action.
        - choice_array (jax.Array): The chosen action in one-hot format.
asymmetric_volatile_rescorla_wagner_single_value_update_choice
module-attribute
asymmetric_volatile_rescorla_wagner_single_value_update_choice: Array = jit(asymmetric_volatile_rescorla_wagner_single_value_update_choice)
Updates the value estimate using the asymmetric volatile dynamic Rescorla-Wagner algorithm, and chooses an option based on the softmax function.
This version of the function is designed for cases where a single value is being learnt, and this value is used to determine which of two options to choose. In practice, the value of option 1 is learnt, and the value of option 2 is set to 1 - value. This is appropriate for cases where the value of one option is the complement of the other.
Note that learning rates for this function are transformed using a sigmoid function to ensure they are between 0 and 1. The raw parameter values supplied to the function must therefore be unbounded.
Parameters:
- value (ArrayLike) – The current value estimate.
- alpha_base (float) – The base learning rate.
- alpha_volatility (float) – The learning rate adjustment for volatile outcomes.
- alpha_pos_neg (float) – The learning rate adjustment for positive and negative prediction errors.
- alpha_interaction (float) – The learning rate adjustment for the interaction between volatility and prediction error sign.
- temperature (float) – The temperature parameter for the softmax function.
- n_actions (int) – The number of actions to choose from.
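A minimal sketch of the two-option choice step implied by this single-value scheme (illustrative only; the function name and temperature convention are assumptions, not the library's implementation):

```python
import jax
import jax.numpy as jnp

def single_value_choice_sketch(value, key, temperature):
    """Illustrative two-option choice when only option 1's value is learnt."""
    # Option 2's value is the complement of option 1's, as described above
    values = jnp.array([value, 1.0 - value])
    choice_p = jax.nn.softmax(values * temperature)  # temperature convention is an assumption
    choice = jax.random.choice(key, jnp.arange(2), p=choice_p)
    return choice_p, choice
```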
asymmetric_volatile_rescorla_wagner_update
asymmetric_volatile_rescorla_wagner_update(value: ArrayLike, outcome_chosen_volatility: Tuple[ArrayLike, ArrayLike, ArrayLike], alpha_base: float, alpha_volatility: float, alpha_pos_neg: float, alpha_interaction: float) -> Tuple[ArrayLike, Tuple[ArrayLike, ArrayLike]]
Updates the estimated value of a state or action using a variant of the Rescorla-Wagner learning rule that adjusts the learning rate based on both volatility and prediction error sign.
Note that learning rates for this function are transformed using a sigmoid function to ensure they are between 0 and 1. The raw parameter values supplied to the function must therefore be unbounded.
Parameters:
- value (ArrayLike) – The current estimated value of a state or action.
- alpha_base (float) – The base learning rate.
- alpha_volatility (float) – The learning rate adjustment for volatile outcomes.
- alpha_pos_neg (float) – The learning rate adjustment for positive and negative prediction errors.
- alpha_interaction (float) – The learning rate adjustment for the interaction between volatility and prediction error sign.

Returns:
- Tuple[ArrayLike, Tuple[ArrayLike, ArrayLike]]:
    - updated_value (ArrayLike): The updated value estimate.
    - output_tuple (Tuple[ArrayLike, ArrayLike]):
        - value (ArrayLike): The original value estimate.
        - prediction_error (ArrayLike): The prediction error.
Source code in behavioural_modelling/learning/rescorla_wagner.py
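As a rough illustration of the rule (the additive combination of the four parameters is an assumption here; see rescorla_wagner.py for the actual implementation), an update step of this kind could look like:

```python
import jax
import jax.numpy as jnp

def volatile_update_sketch(value, outcome_chosen_volatility, alpha_base,
                           alpha_volatility, alpha_pos_neg, alpha_interaction):
    """Illustrative volatility- and sign-dependent Rescorla-Wagner update."""
    outcome, chosen, volatility = outcome_chosen_volatility
    # Prediction error restricted to the chosen action(s)
    prediction_error = (outcome - value) * chosen
    is_positive = (prediction_error > 0).astype(jnp.float32)
    # Unbounded raw parameters are combined and passed through a sigmoid so the
    # effective learning rate stays between 0 and 1 (additive form is an assumption)
    alpha = jax.nn.sigmoid(
        alpha_base
        + alpha_volatility * volatility
        + alpha_pos_neg * is_positive
        + alpha_interaction * volatility * is_positive
    )
    updated_value = value + alpha * prediction_error
    return updated_value, (value, prediction_error)
```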
complement_counterfactual
Counterfactual function that sets the value of unchosen actions to the complement of the value of chosen actions.
Parameters:
- reward (ArrayLike) – The reward received.
- chosen (ArrayLike) – A binary array indicating which action(s) were chosen.

Returns:
- ArrayLike – The counterfactual value.
Source code in behavioural_modelling/learning/rescorla_wagner.py
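A small worked example of the complement logic, using the default lambda shown earlier on this page (complement_counterfactual itself may differ in detail):

```python
import jax.numpy as jnp

# Default counterfactual function as documented above
counterfactual_value = lambda reward, chosen: jnp.where(
    chosen == 1,
    0.0,
    1.0 - jnp.sum(reward * jnp.asarray(chosen == 1, dtype=reward.dtype)),
)

reward = jnp.array([0.0, 0.0])  # the chosen option paid out 0
chosen = jnp.array([1, 0])      # option 0 was chosen
counterfactual_value(reward, chosen)
# -> [0., 1.]: the unchosen option is treated as if it would have paid out 1 - 0 = 1,
#    while chosen positions are set to 0 (their real outcome is used instead)
```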