Value iteration example¶
Value iteration is a dynamic programming algorithm that computes the optimal value function of a Markov decision process (MDP), from which an optimal policy can be read off. It is a "model-based" algorithm, meaning that it requires knowledge of the MDP's transition probabilities and reward function.
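Concretely, the algorithm repeatedly applies the Bellman optimality update until the values converge. Written for state-based rewards, as used in this notebook:

$$V_{k+1}(s) = \max_{a} \sum_{s'} P(s' \mid s, a)\,\left[ R(s) + \gamma V_k(s') \right]$$

where $P(s' \mid s, a)$ is the probability of moving to state $s'$ when taking action $a$ in state $s$, $R(s)$ is the reward for state $s$, and $\gamma$ is the discount factor (the exact placement of $R$ relative to the expectation varies by convention).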
Imports¶
In [1]:
# autoreload
%load_ext autoreload
%autoreload 2
In [ ]:
import numpy as np
import jax.numpy as jnp
from behavioural_modelling.planning.dynamic_programming import (
    solve_value_iteration,
)
import matplotlib.pyplot as plt
Create a function to define a simple gridworld environment¶
In [ ]:
# Simple gridworld setup
def create_gridworld(
    grid_size: int = 15,
    discount: float = 0.9,
    n_rewards: int = 1,
    seed: int = 42,
):
    # Initialise RNG
    rng = np.random.RandomState(seed)

    # Number of states and actions
    n_states = grid_size**2
    n_actions = 5  # up, down, left, right, stay

    # Reward function: place n_rewards unit rewards at random states
    reward_function = np.zeros(n_states)
    for _ in range(n_rewards):
        reward_function[rng.randint(n_states)] = 1.0

    # Feature matrix: identity (each state is its own feature;
    # not used further in this example)
    features = np.eye(n_states)

    # SAS matrix (state, action, state transition probabilities)
    sas = np.zeros((n_states, n_actions, n_states))

    # Helpers to convert between flat state indices and (row, column) coordinates
    def state_to_coords(s, size):
        return s // size, s % size

    def coords_to_state(x, y, size):
        return x * size + y

    # Fill in the (deterministic) state transition probabilities
    for s in range(n_states):
        x, y = state_to_coords(s, grid_size)

        # Up
        if x > 0:
            sas[s, 0, coords_to_state(x - 1, y, grid_size)] = 1.0
        else:
            sas[s, 0, s] = 1.0  # Bump into wall
        # Down
        if x < grid_size - 1:
            sas[s, 1, coords_to_state(x + 1, y, grid_size)] = 1.0
        else:
            sas[s, 1, s] = 1.0  # Bump into wall
        # Left
        if y > 0:
            sas[s, 2, coords_to_state(x, y - 1, grid_size)] = 1.0
        else:
            sas[s, 2, s] = 1.0  # Bump into wall
        # Right
        if y < grid_size - 1:
            sas[s, 3, coords_to_state(x, y + 1, grid_size)] = 1.0
        else:
            sas[s, 3, s] = 1.0  # Bump into wall
        # Stay
        sas[s, 4, s] = 1.0

    return reward_function, features, sas, discount, n_states, n_actions
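Since the transitions are deterministic, every (state, action) slice of sas should be a valid probability distribution. A minimal sanity check (sas_check is just a throwaway name):

# Each (state, action) pair should place probability mass summing to 1
_, _, sas_check, _, _, _ = create_gridworld()
assert np.allclose(sas_check.sum(axis=-1), 1.0)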
Plot the gridworld environment¶
We'll do one with a single positive goal state.
In [ ]:
# Get gridworld setup
reward_function, features, sas, discount, n_states, n_actions = (
    create_gridworld()
)

# Plot the reward function
plt.imshow(reward_function.reshape(15, 15))
Out[ ]:
<matplotlib.image.AxesImage at 0x7f86f845a5f0>
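The single bright cell is the rewarded state; its location is determined by the RNG seed passed to create_gridworld.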
And another with a large negative area.
In [ ]:
# Add a large negative region: each flat-index slice below covers
# columns 4-14 of one row (rows 0-5) of the flattened 15x15 grid
reward_function2 = reward_function.copy()
reward_function2[4:15] = -5
reward_function2[19:30] = -5
reward_function2[34:45] = -5
reward_function2[49:60] = -5
reward_function2[64:75] = -5
reward_function2[79:90] = -5

plt.imshow(reward_function2.reshape(15, 15))
Out[ ]:
<matplotlib.image.AxesImage at 0x7f86c87ec9d0>
Run the value iteration algorithm¶
In [ ]:
# Parameters for the value iteration
max_iter = 1000  # maximum number of sweeps
tol = 1e-4  # convergence tolerance

# Solve for both reward functions
values, q_values = solve_value_iteration(
    sas.shape[0],  # number of states
    sas.shape[1],  # number of actions
    jnp.array(reward_function),
    max_iter,
    discount,
    jnp.array(sas),
    tol,
)
values2, q_values2 = solve_value_iteration(
    sas.shape[0],
    sas.shape[1],
    jnp.array(reward_function2),
    max_iter,
    discount,
    jnp.array(sas),
    tol,
)
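Although we only plot state values below, the returned Q-values also let us read off a greedy policy by taking the argmax over actions in each state. A minimal sketch, assuming q_values has shape (n_states, n_actions) with actions ordered up, down, left, right, stay as in create_gridworld:

# Greedy action in each state (index into the up/down/left/right/stay ordering)
policy = jnp.argmax(q_values, axis=1)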
Plot the value function¶
In [31]:
f, ax = plt.subplots(1, 2, figsize=(10, 5))
# get grid size
grid_size = int(np.sqrt(n_states))
ax[0].imshow(values.reshape(grid_size, grid_size), interpolation="nearest")
ax[1].imshow(values2.reshape(grid_size, grid_size), interpolation="nearest")
Out[31]:
<matplotlib.image.AxesImage at 0x7f86c81425f0>
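In both cases the values peak at the rewarded state and decay with distance from it, roughly geometrically in the discount factor, since the agent must take one step per unit of distance. In the second gridworld, the negative block additionally depresses the values of states within and around it, so the optimal policy routes around that region where possible.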