
Utilities

utils

Functions:

choice_from_action_p

choice_from_action_p(key: PRNGKey, probs: ArrayLike, lapse: float = 0.0) -> int

Choose an action from a set of action probabilities. Probabilities may be given as an N-dimensional array whose last dimension indexes the actions; one choice is drawn for each position in the leading dimensions.

With probability lapse, the choice is replaced by noise: on these "lapse" trials, the subject chooses an action uniformly at random.
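Assume the lapse draw is independent of the action draw. Then the overall probability of choosing action a is the mixture (1 - lapse) * p_a + lapse / n_actions. The short sketch below computes that mixture. The helper effective_choice_probs is a hypothetical name written for illustration; it is not part of behavioural_modelling.

```python
import jax.numpy as jnp


def effective_choice_probs(probs, lapse):
    """Hypothetical helper: choice probabilities after mixing in uniform lapses.

    With probability `lapse` the action is uniform over n_actions,
    otherwise it is drawn from `probs` (last axis = actions).
    """
    probs = jnp.asarray(probs)
    n_actions = probs.shape[-1]
    # Policy probabilities weighted by (1 - lapse), plus an equal lapse share per action
    return (1.0 - lapse) * probs + lapse / n_actions


p = effective_choice_probs([0.7, 0.2, 0.1], lapse=0.3)
print(p)  # approximately 0.59, 0.24, 0.17
```

As lapse approaches 1, every action tends towards probability 1 / n_actions regardless of probs.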

Parameters:

  • key

    (PRNGKey) –

    Jax random key

  • probs

    (ArrayLike) –

    N-dimensional array of action probabilities, of shape (..., n_actions)

  • lapse

    (float, default: 0.0 ) –

    Probability of lapse. Defaults to 0.0.

Returns:

  • jax.Array –

    Chosen action indices, of shape probs.shape[:-1] (a scalar for 1D probs)

Source code in behavioural_modelling/utils.py
@jax.jit
def choice_from_action_p(key: jax.random.PRNGKey, probs: ArrayLike, lapse: float = 0.0) -> int:
    """
    Choose an action from a set of action probabilities. Probabilities may
    be given as an N-dimensional array whose last dimension indexes the
    actions; one choice is drawn for each position in the leading dimensions.

    With probability `lapse`, the choice is replaced by noise: on these
    "lapse" trials, the subject chooses an action uniformly at random.

    Args:
        key (jax.random.PRNGKey): Jax random key
        probs (ArrayLike): N-dimensional array of action probabilities, of shape (..., n_actions)
        lapse (float, optional): Probability of lapse. Defaults to 0.0.
    Returns:
        jax.Array: Chosen action indices, of shape probs.shape[:-1]
    """

    # Convert lists and other array-likes to a JAX array
    probs = jnp.asarray(probs)

    # Flatten all leading dimensions so probs has shape (n_choices, n_actions)
    probs_reshaped = probs.reshape((-1, probs.shape[-1]))

    # Split keys so that we have one for each row of probs_reshaped
    keys = jax.random.split(key, probs_reshaped.shape[0])

    # Get choices. choice_func_vmap is defined elsewhere in utils.py (not
    # shown here); presumably it maps choice_from_action_p_single over rows
    choices = choice_func_vmap(keys, probs_reshaped, lapse)

    # Restore the leading dimensions, dropping the action axis
    choices = choices.reshape(probs.shape[:-1])

    return choices
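choice_from_action_p batches its choices in four steps. It flattens the leading dimensions, splits one key per row, maps a per-row sampler over the rows, then restores the original leading shape. The self-contained sketch below reproduces that pattern.

The names sample_row, sample_rows and sample_batch are hypothetical. For brevity, sample_row omits lapses, so this is an illustration of the batching idea rather than the library's own code.

```python
import jax
import jax.numpy as jnp


def sample_row(key, p):
    # Hypothetical per-row sampler: one categorical draw from p (no lapses)
    return jax.random.choice(key, p.shape[-1], p=p)


# Map over rows: row i of `keys` is paired with row i of the probabilities
sample_rows = jax.vmap(sample_row, in_axes=(0, 0))


@jax.jit
def sample_batch(key, probs):
    # Flatten all leading dimensions into a single batch dimension
    flat = probs.reshape((-1, probs.shape[-1]))
    # One independent key per row
    keys = jax.random.split(key, flat.shape[0])
    choices = sample_rows(keys, flat)
    # Restore the leading dimensions, dropping the action axis
    return choices.reshape(probs.shape[:-1])


probs = jnp.full((2, 3, 4), 0.25)  # 2 x 3 choices, each over 4 actions
choices = sample_batch(jax.random.PRNGKey(0), probs)
print(choices.shape)  # (2, 3)
```

Splitting one key per row keeps the rows statistically independent. Reusing a single key across rows would make rows with identical probabilities produce identical choices.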

choice_from_action_p_single

choice_from_action_p_single(key: PRNGKey, probs: ArrayLike, lapse: float = 0.0) -> int

Choose an action from a set of action probabilities for a single choice.

Parameters:

  • key

    (PRNGKey) –

    Jax random key

  • probs

    (ArrayLike) –

    1D array of action probabilities, of shape (n_actions,)

  • lapse

    (float, default: 0.0 ) –

    Probability of a lapse. On lapse trials, an action is selected uniformly at random. Defaults to 0.0.

Returns:

  • int ( int ) –

    Chosen action

Source code in behavioural_modelling/utils.py
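@jax.jit
def choice_from_action_p_single(
    key: jax.random.PRNGKey, probs: ArrayLike, lapse: float = 0.0
) -> int:
    """
    Choose an action from a set of action probabilities for a single choice.

    Args:
        key (jax.random.PRNGKey): Jax random key
        probs (ArrayLike): 1D array of action probabilities, of shape (n_actions,)
        lapse (float, optional): Probability of a lapse. On lapse trials, an action is selected uniformly at random. Defaults to 0.0.

    Returns:
        int: Chosen action
    """

    probs = jnp.asarray(probs)

    # Get number of possible actions
    n_actions = probs.shape[-1]

    # Avoid zero probabilities, then renormalise so that probs sums to 1
    probs = (probs + 1e-6) / jnp.sum(probs + 1e-6)

    # Use independent keys for the lapse draw, the policy draw and the
    # random draw; reusing one key would correlate these samples
    lapse_key, choice_key, random_key = jax.random.split(key, 3)

    # Decide whether this is a lapse trial
    is_lapse = jax.random.uniform(lapse_key) < lapse

    # Draw from the policy, or uniformly at random on lapse trials
    policy_choice = jax.random.choice(
        choice_key, jnp.arange(n_actions, dtype=int), p=probs
    )
    random_choice = jax.random.randint(
        random_key, shape=(), minval=0, maxval=n_actions
    )
    choice = jnp.where(is_lapse, random_choice, policy_choice)

    return choice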
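One way to sanity-check a single-choice sampler with lapses is to simulate many trials. The empirical choice frequencies are then compared with the mixture (1 - lapse) * p + lapse / n_actions, which holds when the lapse draw is independent of the action draw.

The sampler below, sample_with_lapse, is a hypothetical stand-in written for this check. It draws the lapse decision, the policy choice and the random choice from separately split keys.

```python
import jax
import jax.numpy as jnp


def sample_with_lapse(key, probs, lapse):
    # Hypothetical sampler: independent keys for the lapse, policy and random draws
    lapse_key, choice_key, random_key = jax.random.split(key, 3)
    n_actions = probs.shape[-1]
    is_lapse = jax.random.uniform(lapse_key) < lapse
    policy_choice = jax.random.choice(choice_key, n_actions, p=probs)
    random_choice = jax.random.randint(random_key, (), 0, n_actions)
    return jnp.where(is_lapse, random_choice, policy_choice)


probs = jnp.array([0.7, 0.2, 0.1])
lapse = 0.3
n_samples = 20_000

# One independent key per simulated trial; probs and lapse are shared
keys = jax.random.split(jax.random.PRNGKey(1), n_samples)
choices = jax.vmap(sample_with_lapse, in_axes=(0, None, None))(keys, probs, lapse)

# Empirical choice frequencies versus the lapse mixture
empirical = jnp.bincount(choices, length=probs.shape[0]) / n_samples
expected = (1 - lapse) * probs + lapse / probs.shape[0]
print(empirical, expected)
```

With 20,000 trials, each empirical frequency should land within about a percentage point of the mixture value. A systematic gap of this kind can expose correlated random draws, such as a single key reused across the lapse and action draws.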