Reward Machines & Automata

Introduction to Reward Machines

Reinforcement learning agents typically learn from simple reward signals that are functions of the current state or state-action pair. While effective for basic tasks, this approach struggles with temporally extended goals that depend on complex sequences of events. Reward Machines (RMs) address this limitation by providing a formal framework to specify tasks through automata that map sequences of events to rewards. These machines:
  • Define rewards based on the history of events, not just the current state
  • Enable specification of temporally extended tasks with complex logic
  • Allow for structured, interpretable task definitions

Types of Reward Machines

The CRM framework supports two main types of Reward Machines:

Vanilla Reward Machines

Vanilla Reward Machines are a special case of Counting Reward Machines: they define rewards based on sequences of symbolic events, without any extended automaton memory. They consist of:
  • States: Different phases of the task
  • Transitions: Rules for moving between states based on symbolic events
  • Rewards: Values assigned during transitions

Counting Reward Machines (CRMs)

Counting Reward Machines extend Vanilla RMs with counter variables, an extended automaton memory that tracks how often events occur. They add:
  • Counters: Variables that increment or decrement based on events
  • Counter Conditions: Rules that use counter values to determine transitions
This extension significantly increases the expressive power of RMs: a CRM can specify any task that can be expressed as a well-defined algorithm.

Components of a Reward Machine

All Reward Machines in the CRM framework share these core components:

1. States and Initial State

States represent different phases or stages of a task. Each RM has:
  • A finite set of states
  • An initial state
  • (Optionally) terminal states that end the task (denoted by -1 in the CRM framework)
@property
def u_0(self) -> int:
    """Return the initial state of the machine."""
    return 0

2. Transition Functions

Three key functions define the behavior of a Reward Machine:

State Transition Function (delta_u)

Defines how the machine moves between states:
def _get_state_transition_function(self) -> dict:
    return {
        0: {  # From state 0
            "EVENT_A / (-)": 1,  # When EVENT_A occurs, go to state 1
            "/ (-)": 0,          # Default: stay in state 0
        },
        1: {  # From state 1
            "EVENT_B / (-)": -1,  # When EVENT_B occurs, go to terminal state
            "/ (-)": 1,           # Default: stay in state 1
        }
    }

Counter Transition Function (delta_c)

Defines how counter values change (for CRMs):
def _get_counter_transition_function(self) -> dict:
    return {
        0: {  # From state 0
            "EVENT_A / (-)": (1,),  # Increment counter by 1
            "/ (-)": (0,),          # Default: no change
        },
        1: {  # From state 1
            "EVENT_B / (-)": (0,),  # No change to counter
            "/ (-)": (0,),          # Default: no change
        }
    }

Reward Transition Function (delta_r)

Defines rewards given during transitions:
def _get_reward_transition_function(self) -> dict:
    return {
        0: {  # From state 0
            "EVENT_A / (-)": 0.1,  # Small reward for EVENT_A
            "/ (-)": -0.01,        # Small penalty for no event
        },
        1: {  # From state 1
            "EVENT_B / (-)": 1.0,  # Large reward for EVENT_B
            "/ (-)": -0.01,        # Small penalty for no event
        }
    }
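
Read together, the three functions above define one small task; the behaviour they induce is:
(u=0) --EVENT_A--> (u=1)             # reward 0.1, counter incremented
(u=1) --EVENT_B--> (u=-1, terminal)  # reward 1.0, counter unchanged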

3. Counters (for CRMs)

Counting Reward Machines maintain counter variables:
@property
def c_0(self) -> tuple[int, ...]:
    """Return the initial counter configuration."""
    return (0,)  # Start with a single counter at 0

Creating a Vanilla Reward Machine

Vanilla Reward Machines are ideal for tasks that involve sequences of events without counting. Here’s a complete example:
from crm.automaton import CountingRewardMachine
from my_environment.events import Event  # Your event enum

class SequenceRM(CountingRewardMachine):
    """Reward machine for a simple sequence: A -> B -> C."""
    
    def __init__(self):
        super().__init__(env_prop_enum=Event)
    
    @property
    def u_0(self) -> int:
        """Initial state."""
        return 0
    
    @property
    def c_0(self) -> tuple[int, ...]:
        """Initial counter (unused)."""
        return (0,)
    
    @property
    def encoded_configuration_size(self) -> int:
        return 1
    
    def _get_state_transition_function(self) -> dict:
        return {
            0: {"A / (-)": 1, "/ (-)": 0},
            1: {"B / (-)": 2, "/ (-)": 1},
            2: {"C / (-)": -1, "/ (-)": 2}, # -1 denotes terminal state
        }
    
    def _get_counter_transition_function(self) -> dict:
        # Counters don't change in a Vanilla RM
        return {
            0: {"A / (-)": (0,), "/ (-)": (0,)},
            1: {"B / (-)": (0,), "/ (-)": (0,)},
            2: {"C / (-)": (0,), "/ (-)": (0,)},
        }
    
    def _get_reward_transition_function(self) -> dict:
        return {
            0: {"A / (-)": 0.1, "/ (-)": -0.01},
            1: {"B / (-)": 0.2, "/ (-)": -0.01},
            2: {"C / (-)": 1.0, "/ (-)": -0.01},
        }
    
    def _get_possible_counter_configurations(self) -> list[tuple[int]]:
        return [(0,)]  # Only one possible configuration
        
    def sample_counter_configurations(self) -> list[tuple[int]]:
        return self._get_possible_counter_configurations()
This RM rewards the agent for observing events A, then B, then C in sequence.
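
The example assumes an Event enum whose member names match the symbols used in the transition expressions. A minimal sketch of what my_environment/events.py might contain (the members shown are assumptions for this example):
from enum import Enum, auto

class Event(Enum):
    """Symbolic events produced by the environment's labelling function."""
    A = auto()
    B = auto()
    C = auto()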

Simplified Syntax for Vanilla RMs

For Vanilla Reward Machines, the transition expressions can be simplified to just the event expressions without the counter component:
def _get_state_transition_function(self) -> dict:
    return {
        0: {
            "A": 1,  # Simply use the event name
            "": 0,   # Empty string for default transition
        },
        1: {
            "B": 2,
            "": 1,
        },
    }
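
The same shorthand presumably applies wherever transition expressions appear in a Vanilla RM; as an illustration (an assumption extrapolated from the rule above, not verified against the framework), the corresponding reward function could be written:
def _get_reward_transition_function(self) -> dict:
    return {
        0: {"A": 0.1, "": -0.01},
        1: {"B": 0.2, "": -0.01},
    }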

Creating a Counting Reward Machine

Counting Reward Machines are used for more complex task specifications. Here’s an example that counts and balances events:
class BalancedCRM(CountingRewardMachine):
    """CRM that requires equal numbers of A and C events."""
    
    def __init__(self):
        super().__init__(env_prop_enum=Event)
    
    @property
    def u_0(self) -> int:
        return 0
    
    @property
    def c_0(self) -> tuple[int, ...]:
        return (0,)  # Start with counter at 0
    
    @property
    def encoded_configuration_size(self) -> int:
        return 2
    
    def _get_state_transition_function(self) -> dict:
        return {
            0: {
                "A / (-)": 0,  # Stay in state 0 when seeing A
                "B / (-)": 1,  # Move to state 1 when seeing B
                "C / (-)": 0,  # Stay in state 0 when seeing C
                "/ (-)": 0,    # Default: stay in state 0
            },
            1: {
                "A / (-)": 1,  # Stay in state 1 for any A
                "B / (-)": 1,  # Stay in state 1 for any B
                "C / (NZ)": 1,  # Stay in state 1 for C if counter > 0
                "C / (Z)": -1,  # Terminal state for C if counter = 0
                "/ (-)": 1,     # Default: stay in state 1
            },
        }
    
    def _get_counter_transition_function(self) -> dict:
        return {
            0: {
                "A / (-)": (1,),  # Increment counter for A
                "B / (-)": (0,),  # No change for B
                "C / (-)": (0,),  # No change for C in state 0
                "/ (-)": (0,),    # Default: no change
            },
            1: {
                "A / (-)": (0,),     # No change for A in state 1
                "B / (-)": (0,),     # No change for B
                "C / (NZ)": (-1,),   # Decrement counter for C if > 0
                "C / (Z)": (0,),     # No change for C if counter = 0
                "/ (-)": (0,),       # Default: no change
            },
        }
    
    def _get_reward_transition_function(self) -> dict:
        return {
            0: {
                "A / (-)": -0.1,  # Small negative reward (step cost)
                "B / (-)": -0.1,
                "C / (-)": -0.1,
                "/ (-)": -0.1,
            },
            1: {
                "A / (-)": -0.1,
                "B / (-)": -0.1,
                "C / (NZ)": -0.1,
                "C / (Z)": 1.0,   # Positive reward for completion
                "/ (-)": -0.1,
            },
        }
    
    def _get_possible_counter_configurations(self) -> list[tuple[int]]:
        return [(i,) for i in range(6)]  # Counter can be 0-5
    
    def sample_counter_configurations(self) -> list[tuple[int]]:
        return self._get_possible_counter_configurations()
This CRM defines a task that:
  1. Counts occurrences of A in state 0
  2. Transitions to state 1 when B is observed
  3. Decrements the counter for each C in state 1
  4. Terminates with reward 1.0 when C is observed while the counter is 0, i.e. once every counted A has been matched by a C (see the trace below)
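
To make this concrete, here is how the machine configuration (state u, counter c) evolves on the event sequence A, A, B, C, C, C, as dictated by the transition functions above:
(u=0, c=(0,)) --A--> (u=0, c=(1,))     # count the first A
(u=0, c=(1,)) --A--> (u=0, c=(2,))     # count the second A
(u=0, c=(2,)) --B--> (u=1, c=(2,))     # enter the matching phase
(u=1, c=(2,)) --C--> (u=1, c=(1,))     # "C / (NZ)": decrement
(u=1, c=(1,)) --C--> (u=1, c=(0,))     # "C / (NZ)": decrement
(u=1, c=(0,)) --C--> (u=-1, terminal)  # "C / (Z)": reward 1.0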

Transition Expression Syntax

Reward Machines use a specific syntax for transition expressions:
"EVENT_EXPRESSION / COUNTER_CONDITION"

Event Expressions (Well-Formed Formulas)

The event expression part supports Boolean logic, allowing for complex combinations:
  • Single event: "A / (-)"
  • Logical AND: "A and B / (-)" (both A and B must occur)
  • Logical OR: "A or B / (-)" (either A or B must occur)
  • Logical NOT: "not A / (-)" (A must not occur)
  • Complex expressions:
    • "A and not B / (-)" (A must occur and B must not occur)
    • "not (A or B) / (-)" (neither A nor B can occur)
    • "(A and B) or C / (-)" (either both A and B, or C must occur)
These are called well-formed formulas (WFFs) and provide powerful logical conditions.

Counter Conditions

For Counting Reward Machines, counter conditions specify requirements for counters:
  • Any value: "A / (-)"
  • Zero: "A / (Z)"
  • Non-zero: "A / (NZ)"
  • Multiple counters: "A / (Z,NZ)" (first counter zero, second non-zero)

Advanced Features

Custom Reward Functions

Instead of constant rewards, RMs can use callable functions that compute rewards based on the specific details of the transition:
import numpy as np

def distance_based_reward(obs, action, next_obs):
    """Reward based on distance to goal."""
    goal_position = np.array([10, 10])
    agent_position = next_obs[1:3]  # assumes obs[1:3] holds the agent's (x, y) position
    distance = np.linalg.norm(agent_position - goal_position)
    return 1.0 / (1.0 + distance)  # Higher reward when closer

def _get_reward_transition_function(self) -> dict:
    return {
        0: {
            "GOAL_VISIBLE / (-)": distance_based_reward,
            "/ (-)": -0.1,
        }
    }

Full Reward Functions

The reward transition function (delta_r) can return either:
  1. Scalar values: Simple numerical rewards (converted internally to constant functions)
  2. Callable functions: Full reward functions that receive the transition details
When using callable reward functions, they receive the raw observations and actions from the environment:
def time_based_penalty(obs, action, next_obs):
    """Penalty that increases over time."""
    time_step = next_obs[0]  # Assuming first element tracks time
    return -0.01 * time_step  # Greater penalty as time increases

def combined_reward(obs, action, next_obs):
    """Combine multiple reward components."""
    # calculate_progress and calculate_energy stand in for user-defined
    # helpers; they are shown here only to illustrate the pattern.
    progress_reward = calculate_progress(obs, next_obs)
    energy_penalty = -0.1 * calculate_energy(action)
    return progress_reward + energy_penalty
This approach enables sophisticated reward shaping strategies that depend on:
  • Continuous state variables
  • Action costs or complexities
  • Time-based factors
  • Multiple weighted reward components

Multiple Counters

For more complex tasks, you can use multiple counters:
@property
def c_0(self) -> tuple[int, ...]:
    return (0, 0)  # Two counters, both starting at 0

def _get_counter_transition_function(self) -> dict:
    return {
        0: {
            "A / (-,-)": (1, 0),  # Increment first counter only
            "B / (-,-)": (0, 1),  # Increment second counter only
            "C / (-,-)": (1, 1),  # Increment both counters
            "/ (-,-)": (0, 0),    # No change to any counter
        }
    }
Counter conditions can then reference specific counters:
def _get_state_transition_function(self) -> dict:
    return {
        0: {
            "A / (Z,NZ)": 1,  # A with first counter=0, second counter>0
            "B / (NZ,Z)": 2,  # B with first counter>0, second counter=0
            "/ (-,-)": 0,
        }
    }
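
For intuition, with the function above:
c = (0, 3), event A  ->  matches "A / (Z,NZ)", move to state 1
c = (2, 0), event B  ->  matches "B / (NZ,Z)", move to state 2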

Complex Event Conditions

The transition expressions support rich Boolean logic on events:
def _get_state_transition_function(self) -> dict:
    return {
        0: {
            "A and B / (-)": 1,      # Both A and B must occur
            "A and not C / (-)": 2,  # A must occur, C must not
            "not (A or B) / (-)": 3, # Neither A nor B can occur
            "/ (-)": 0,
        }
    }
These expressions are evaluated against the events detected by the labelling function (a sketch follows this list). For a transition to match:
  1. The Boolean logic on events must evaluate to true
  2. The counter conditions must be satisfied
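
The labelling function itself is environment-specific: conceptually, it maps each raw transition to the set of symbolic events that currently hold. A minimal sketch, with a hypothetical observation layout and conditions (the exact interface expected by the CRM framework may differ):
def labelling_function(obs, action, next_obs) -> set[Event]:
    """Map a raw environment transition to the symbolic events that hold."""
    events = set()
    if next_obs[0] > 5.0:  # hypothetical condition signalling event A
        events.add(Event.A)
    if next_obs[1] < 0.0:  # hypothetical condition signalling event B
        events.add(Event.B)
    return events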

Best Practices

When creating Reward Machines:
  1. Start Simple: Begin with Vanilla RMs before adding counters
  2. Use Meaningful States: Design states to represent logical phases of the task
  3. Consistent Structure: Ensure all three transition functions have the same keys
  4. Default Transitions: Always include "/ (-)" transitions for all states
  5. Terminal States: Use -1 to indicate task completion or failure
  6. Counter Limits: Set bounds on counter values to prevent unbounded counting
  7. Test Transitions: Verify that all possible event combinations are handled

Summary

Counting Reward Machines provide a powerful and flexible framework for specifying complex tasks in reinforcement learning:
  • Vanilla Reward Machines define tasks based on sequences of events
  • Counting Reward Machines add counters to enable more complex task specifications
  • The formalism cleanly separates task specification from environment dynamics
  • Boolean logic allows for complex event conditions
  • Custom reward functions enable sophisticated reward shaping
By using Counting Reward Machines, you can create agents that learn to solve tasks requiring memory, counting, and complex sequential behavior in a structured, interpretable way.