Counterfactual Q-Learning in Letter World

Counterfactual Q-Learning uses the structure of Counting Reward Machines to generate additional learning experiences, dramatically improving sample efficiency.

The Concept

Counterfactual Q-Learning leverages the structured nature of Counting Reward Machines to generate additional “what-if” experiences that the agent can learn from without actually having to explore those states. By using the symbolic representation in the CRM (see the sketch after this list), we can:
  • Infer consequences of different actions
  • Update multiple state-action values at once
  • Accelerate learning significantly
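
As a rough, self-contained illustration of the idea, consider a toy two-state machine without counters (this is a hand-written stand-in, not the library's API): a single observed event is replayed against every machine state the agent could be in, giving several Q-updates from one real step.

import numpy as np
from collections import defaultdict

# Toy two-state machine (no counters): seeing "A" moves state 0 -> 1,
# seeing "B" in state 1 terminates with reward 1; anything else self-loops.
def toy_transition(u, event):
    if u == 0 and event == "A":
        return 1, 0.0
    if u == 1 and event == "B":
        return -1, 1.0  # -1 marks the terminal machine state
    return u, 0.0

q = defaultdict(lambda: np.zeros(2))  # Q-values over (ground_obs, machine_state)
ground_obs, action, next_ground_obs, event = (0, 0), 1, (0, 1), "A"

# One real step, but a Q-update for every machine state the agent could be in.
for u in (0, 1):
    u_next, r = toy_transition(u, event)
    target = r if u_next == -1 else r + 0.99 * np.max(q[(next_ground_obs, u_next)])
    q[(ground_obs, u)][action] += 0.1 * (target - q[(ground_obs, u)][action])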

Implementation Comparison

Let’s compare standard Q-Learning with Counterfactual Q-Learning in the Letter World environment.
The examples are not currently distributed in the PyPI package. Please see the Installation Guide for instructions on setting up a development build.

Standard Q-Learning

from collections import defaultdict

import numpy as np

from examples.letter.ground import LetterWorld
from examples.letter.label import LetterWorldLabellingFunction
from examples.letter.machine import LetterWorldCountingRewardMachine
from examples.letter.crossproduct import LetterWorldCrossProduct

# 1. Create the ground environment
ground_env = LetterWorld()

# 2. Create the labelling function
lf = LetterWorldLabellingFunction()

# 3. Create the CRM
crm = LetterWorldCountingRewardMachine()

# 4. Create the cross-product MDP
env = LetterWorldCrossProduct(
    ground_env=ground_env,
    crm=crm,
    lf=lf,
    max_steps=500,
)

# Hyperparameters
EPISODES = 5000
LEARNING_RATE = 0.01
DISCOUNT_FACTOR = 0.99
EPSILON = 0.1

# Initialize Q-table
q_table = defaultdict(lambda: np.zeros(env.action_space.n))

# Train standard Q-learning agent
for episode in range(EPISODES):
    obs, _ = env.reset()
    done = False
    
    while not done:
        # Epsilon-greedy policy
        if np.random.random() < EPSILON:
            action = np.random.randint(env.action_space.n)
        else:
            action = int(np.argmax(q_table[tuple(obs)]))
            
        # Execute action
        next_obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        
        # Q-learning update
        if done:
            td_target = reward
        else:
            td_target = reward + DISCOUNT_FACTOR * np.max(q_table[tuple(next_obs)])
            
        td_error = td_target - q_table[tuple(obs)][action]
        q_table[tuple(obs)][action] += LEARNING_RATE * td_error
        
        obs = next_obs

Counterfactual Q-Learning

# Initialize Q-table for counterfactual agent
q_table = defaultdict(lambda: np.zeros(env.action_space.n))

# Train counterfactual Q-learning agent
for episode in range(EPISODES):
    obs, _ = env.reset()
    done = False
    
    while not done:
        # Epsilon-greedy policy
        if np.random.random() < EPSILON:
            action = np.random.randint(env.action_space.n)
        else:
            action = int(np.argmax(q_table[tuple(obs)]))
            
        # Execute action
        next_obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        
        # Generate and learn from counterfactual experiences
        for o, a, o_, r, d, _ in zip(
            *env.generate_counterfactual_experience(
                env.to_ground_obs(obs),
                action,
                env.to_ground_obs(next_obs),
            ),
            strict=True,
        ):
            if not d:
                q_table[tuple(o)][a] += LEARNING_RATE * (
                    r
                    + DISCOUNT_FACTOR * np.max(q_table[tuple(o_)])
                    - q_table[tuple(o)][a]
                )
            else:
                q_table[tuple(o)][a] += LEARNING_RATE * (r - q_table[tuple(o)][a])
        
        # Standard Q-learning update for actual experience
        if done:
            td_target = reward
        else:
            td_target = reward + DISCOUNT_FACTOR * np.max(q_table[tuple(next_obs)])
            
        td_error = td_target - q_table[tuple(obs)][action]
        q_table[tuple(obs)][action] += LEARNING_RATE * td_error
        
        obs = next_obs
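
Before comparing the two agents, it helps to roll out each learned Q-table greedily. The evaluate helper below is a sketch for this guide, not part of the library; it reuses the cross-product env and a q_table from the loops above.

def evaluate(q_table, env, episodes=100):
    """Average return of the greedy policy derived from q_table."""
    total = 0.0
    for _ in range(episodes):
        obs, _ = env.reset()
        done = False
        while not done:
            # Always pick the greedy action (no exploration during evaluation)
            action = int(np.argmax(q_table[tuple(obs)]))
            obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            total += reward
    return total / episodes

print(f"Greedy average return: {evaluate(q_table, env):.2f}")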

Performance Comparison

When we run both algorithms on the Letter World environment, we see a significant improvement in learning efficiency with Counterfactual Q-Learning. As the learning curves show (see the plotting sketch after this list):
  • Counterfactual Q-Learning (orange) learns much faster and achieves higher returns
  • Standard Q-Learning (blue) requires many more episodes to approach similar performance
  • Both eventually converge, but counterfactual learning requires significantly fewer samples
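
To reproduce a comparison like this, one approach is to record each episode's return inside both training loops (for example in lists named standard_returns and counterfactual_returns, which are not part of the code above) and plot a moving average. A minimal sketch:

import matplotlib.pyplot as plt
import numpy as np

def moving_average(returns, window=100):
    """Smooth noisy per-episode returns into a readable learning curve."""
    kernel = np.ones(window) / window
    return np.convolve(returns, kernel, mode="valid")

# standard_returns / counterfactual_returns: per-episode returns recorded
# in the two training loops above (hypothetical names, not defined there).
plt.plot(moving_average(standard_returns), label="Standard Q-Learning")
plt.plot(moving_average(counterfactual_returns), label="Counterfactual Q-Learning")
plt.xlabel("Episode")
plt.ylabel("Average return")
plt.legend()
plt.show()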

How It Works

The counterfactual learning process works through these steps:
  1. The agent takes an action in the environment
  2. The CRM uses its symbolic structure to infer what would have happened for other state-action pairs
  3. These counterfactual experiences are used to update multiple Q-values simultaneously
  4. This process effectively multiplies the learning signal from each real experience, as the sketch below shows
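
To get a concrete sense of this multiplication, we can inspect how many counterfactual tuples a single environment step produces. A minimal sketch, reusing the cross-product env built earlier (the exact count depends on the CRM's structure):

obs, _ = env.reset()
action = env.action_space.sample()
next_obs, reward, terminated, truncated, _ = env.step(action)

# generate_counterfactual_experience returns six parallel lists:
# observations, actions, next observations, rewards, dones and infos.
counterfactuals = env.generate_counterfactual_experience(
    env.to_ground_obs(obs),
    action,
    env.to_ground_obs(next_obs),
)
print(f"1 real step -> {len(counterfactuals[0])} counterfactual experiences")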

Code Breakdown: Generating Counterfactual Experiences

The magic happens in the generate_counterfactual_experience method:
def generate_counterfactual_experience(self, ground_obs, action, next_ground_obs):
    # Decode the current and next ground observations
    current_pos = self.ground_env.decode_obs(ground_obs)
    next_pos = self.ground_env.decode_obs(next_ground_obs)

    # Parallel lists holding all counterfactual experiences
    obs_list = []
    action_list = []
    next_obs_list = []
    reward_list = []
    done_list = []
    info_list = []

    # For each possible CRM state, create a counterfactual experience
    for crm_state in self.crm.states:
        # [implementation details]
        ...

    return obs_list, action_list, next_obs_list, reward_list, done_list, info_list

Benefits and Applications

Counterfactual Q-Learning offers several advantages:
  • Sample Efficiency: Learns from fewer real-world interactions
  • Faster Convergence: Reaches optimal policy more quickly
  • Better Exploration: Gains information about states it has not physically visited
  • Interpretability: Leverages symbolic structure of CRMs
This approach is particularly useful in environments where:
  • Exploration is expensive or risky
  • Task structures have clear symbolic representations
  • Sample efficiency is critical

Conclusion

Counterfactual Q-Learning demonstrates a powerful approach to accelerating learning in environments with structured symbolic representations. By leveraging the CRM’s symbolic structure, we can generate additional learning experiences that significantly improve sample efficiency and convergence speed.