Labelling Functions

Overview

In the Counting Reward Machine (CRM) framework, labelling functions play a critical role in abstracting raw environment observations into meaningful symbolic events. These events form the vocabulary that the CRM uses to define rewards and transitions.

Core Concept

A labelling function maps an environment transition (observation, action, next observation) to the set of symbolic events that occur in it, which is equivalent to a truth assignment over a fixed set of propositions. This abstraction layer provides several benefits:
  • Modularity: Separate environment dynamics from task specification
  • Reusability: Same environment can support different tasks by changing the labelling
  • Expressiveness: Define complex tasks in terms of abstract events rather than raw states
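
As a rough illustration of this mapping, the sketch below writes two labelling functions as plain Python functions over the same transition. The toy grid, the `Event` enum, the `GOAL` cell, and both function names are hypothetical and are not part of the CRM framework; the point is only that one environment can yield different event vocabularies.

```python
from enum import Enum, auto


class Event(Enum):
    """Hypothetical event vocabulary for a toy grid world."""

    REACHED_GOAL = auto()
    MOVED = auto()


GOAL = (2, 2)  # assumed goal cell for this sketch


def goal_labelling(obs: tuple, action: int, next_obs: tuple) -> set:
    """Label a transition with REACHED_GOAL when the agent lands on the goal."""
    return {Event.REACHED_GOAL} if next_obs == GOAL else set()


def movement_labelling(obs: tuple, action: int, next_obs: tuple) -> set:
    """Label a transition with MOVED when the agent's position changes."""
    return {Event.MOVED} if next_obs != obs else set()


# One transition, two labellings: (observation, action, next observation).
transition = ((2, 1), 0, (2, 2))
goal_events = goal_labelling(*transition)
movement_events = movement_labelling(*transition)
```

Swapping `goal_labelling` for `movement_labelling` changes which events a task can refer to without touching the environment itself.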

Implementation

Base Class

The CRM framework provides a base LabellingFunction class that handles the core functionality:
from abc import ABC
from enum import Enum
from typing import Any, Generic, TypeVar

ObsType = TypeVar("ObsType")  # Type of observations
ActType = TypeVar("ActType")  # Type of actions

class LabellingFunction(ABC, Generic[ObsType, ActType]):
    """Base class for labelling functions."""
    
    @staticmethod
    def event(func):
        """Decorator to register an event detection method."""
        setattr(func, "_is_event_method", True)
        return func
        
    def __call__(self, obs: ObsType, action: ActType, next_obs: ObsType) -> set[Enum]:
        """Detect all events in a transition."""
        # Implementation details...
The base class provides:
  1. Type parameters for observation and action types
  2. An event decorator to identify event detection methods
  3. A __call__ method that runs all event detectors on a transition

Creating a Labelling Function

To create a custom labelling function, follow these steps:
  1. Define your events as an Enum:
    from enum import Enum, auto

    class Symbol(Enum):
        A = auto()
        B = auto()
        C = auto()
    
  2. Create a subclass of LabellingFunction:
    class MyLabellingFunction(LabellingFunction[MyObsType, MyActType]):
        # Implementation...
    
  3. Implement event detection methods with the @event decorator:
    @LabellingFunction.event
    def detect_event_a(self, obs, action, next_obs) -> Symbol | None:
        # Logic to detect event A; `condition_met` is a placeholder
        if condition_met:
            return Symbol.A
        return None
    
Each event method should:
  • Accept the parameters (self, obs, action, next_obs)
  • Return an Enum value when the event is detected
  • Return None when the event is not detected

How It Works

When you call a labelling function with lf(obs, action, next_obs), it:
  1. Looks for all methods decorated with @LabellingFunction.event
  2. Calls each method with the provided transition
  3. Collects all non-None results into a set
  4. Returns this set of detected events
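
The four steps above can be sketched as follows. This is one possible way to implement the dispatch, not the framework's actual code: `SketchLabellingFunction`, the module-level `event` helper, and the `Light` example are hypothetical stand-ins, and the real base class may discover its event methods differently.

```python
import inspect
from enum import Enum, auto


def event(func):
    """Mark a method as an event detector (same marker as the decorator above)."""
    setattr(func, "_is_event_method", True)
    return func


class SketchLabellingFunction:
    """Hypothetical stand-in for the base class; not the framework's code."""

    def __call__(self, obs, action, next_obs) -> set:
        events = set()
        # Step 1: find every bound method carrying the marker attribute.
        for _, method in inspect.getmembers(self, predicate=inspect.ismethod):
            if getattr(method, "_is_event_method", False):
                # Step 2: call the detector on the transition.
                result = method(obs, action, next_obs)
                # Step 3: keep only detected (non-None) events.
                if result is not None:
                    events.add(result)
        # Step 4: return the set of detected events.
        return events


class Light(Enum):
    ON = auto()
    OFF = auto()


class LightLabelling(SketchLabellingFunction):
    """Toy example: observations are booleans (light on or off)."""

    @event
    def turned_on(self, obs, action, next_obs):
        return Light.ON if (not obs and next_obs) else None

    @event
    def turned_off(self, obs, action, next_obs):
        return Light.OFF if (obs and not next_obs) else None


lf = LightLabelling()
on_events = lf(False, 1, True)   # light switches on
no_events = lf(True, 0, True)    # nothing changes
```

Because the result is a set, several detectors can fire on the same transition, and the order in which they run does not affect the output.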

Example: Letter World

Let’s examine a complete example from the Letter World environment:
from enum import Enum, auto
import numpy as np
from crm.label import LabellingFunction

class Symbol(Enum):
    """Symbols in the Letter World environment."""
    A = auto()
    B = auto()
    C = auto()

class LetterWorldLabellingFunction(LabellingFunction[np.ndarray, int]):
    """Labelling function for the Letter World environment."""

    A_B_POSITION = np.array([1, 1])
    C_POSITION = np.array([1, 5])

    @LabellingFunction.event
    def test_symbol_a(
        self, obs: np.ndarray, action: int, next_obs: np.ndarray
    ) -> Symbol | None:
        """Return Symbol.A if the agent observes symbol A, otherwise None."""
        if next_obs[0] == 0 and np.array_equal(next_obs[1:], self.A_B_POSITION):
            return Symbol.A
        return None

    @LabellingFunction.event
    def test_symbol_b(
        self, obs: np.ndarray, action: int, next_obs: np.ndarray
    ) -> Symbol | None:
        """Return Symbol.B if the agent observes symbol B, otherwise None."""
        if next_obs[0] == 1 and np.array_equal(next_obs[1:], self.A_B_POSITION):
            return Symbol.B
        return None

    @LabellingFunction.event
    def test_symbol_c(
        self, obs: np.ndarray, action: int, next_obs: np.ndarray
    ) -> Symbol | None:
        """Return Symbol.C if the agent observes symbol C, otherwise None."""
        if next_obs[0] == 1 and np.array_equal(next_obs[1:], self.C_POSITION):
            return Symbol.C
        return None
This example:
  1. Defines three symbols (A, B, C) as events
  2. Creates a labelling function that works with numpy array observations and integer actions
  3. Implements three event detection methods, each checking specific conditions
  4. Returns the appropriate symbol when conditions are met

Using Labelling Functions

Once defined, you use a labelling function as part of a cross-product environment:
# Create components
ground_env = LetterWorld()
lf = LetterWorldLabellingFunction()
crm = LetterWorldCountingRewardMachine()

# Create cross-product environment
env = LetterWorldCrossProduct(
    ground_env=ground_env,
    crm=crm,
    lf=lf,
    max_steps=500,
)

# When env.step() is called, the labelling function
# automatically converts observations to events
next_obs, reward, terminated, truncated, info = env.step(action)
detected_events = info.get('events', [])  # Events detected in this step

Best Practices

When creating labelling functions:
  1. Keep events meaningful: Define events that have semantic meaning for your task
  2. Be deterministic: The same transition should always produce the same set of events
  3. Handle edge cases: Consider all possible transitions and what they should mean
  4. Document conditions: Clearly document when each event is triggered
  5. Use clear naming: Name your event detection methods descriptively
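
A minimal sketch of how practices 2 to 4 might look for a single detector. The `entered_door` function, the `DOOR` event, and the door position are hypothetical; the detector is written as a plain function here so the example stays self-contained, but the same checks would apply to a method decorated with @LabellingFunction.event.

```python
from enum import Enum, auto
from typing import Optional


class Symbol(Enum):
    DOOR = auto()


DOOR_POSITION = (3, 0)  # assumed door cell for this sketch


def entered_door(obs: tuple, action: int, next_obs: tuple) -> Optional[Symbol]:
    """Fire only on the step where the agent moves onto the door cell.

    Staying on the door (obs already equals DOOR_POSITION) does not
    re-trigger the event, which is one way to treat that edge case.
    """
    was_on_door = obs == DOOR_POSITION
    is_on_door = next_obs == DOOR_POSITION
    return Symbol.DOOR if is_on_door and not was_on_door else None


# Edge cases written down as explicit (obs, action, next_obs, expected) rows.
cases = [
    ((2, 0), 1, (3, 0), Symbol.DOOR),  # entering the cell
    ((3, 0), 0, (3, 0), None),         # staying on the cell
    ((3, 0), 2, (2, 0), None),         # leaving the cell
]

for obs, action, next_obs, expected in cases:
    first = entered_door(obs, action, next_obs)
    second = entered_door(obs, action, next_obs)
    assert first == expected  # the documented condition holds
    assert first == second    # same transition, same label
```

Comparing `obs` with `next_obs` is a design choice: it makes the event mark the moment of arrival rather than every step spent on the cell, which matters when a task counts occurrences of an event.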

Summary

Labelling functions are a powerful abstraction mechanism in the CRM framework. They:
  • Convert low-level observations to high-level events
  • Allow for modular task specification
  • Enable the same environment to support different tasks
  • Provide a clean separation between environment dynamics and reward logic
By designing your labelling function carefully, you can specify complex tasks that would be difficult to express with a reward function defined directly on raw states.