gfn.gym.bitSequenceNonAutoregressive

Non-autoregressive BitSequence environment for GFlowNets.

This environment implements a non-autoregressive version of the bit sequence generation task, in which each action encodes both the position to fill and the word value to place there. Unlike the standard (autoregressive) BitSequence environment, which appends words left-to-right, this environment allows any unfilled position to be filled in any order.

This formulation matches the one used by the GFNX (JAX-based) library, enabling fair cross-library benchmarking.

Environment details:
  • State: Tensor of shape (words_per_seq,) with values in {-1, 0, ..., 2^word_size - 1}. -1 indicates an unfilled position.

  • Initial state s0: All positions unfilled, [-1, -1, ..., -1].

  • Terminal states: All positions filled (no -1 values).

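  • Forward actions: words_per_seq * n_words placement actions, where each action a encodes (position, word) = divmod(a, n_words), plus one exit action (the last action, index words_per_seq * n_words), for words_per_seq * n_words + 1 actions in total. The exit action is available only at terminal states.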

  • Backward actions: words_per_seq * n_words actions. The backward action for a forward action (pos, word) is the same index — it clears that position back to -1.

  • Reward: Based on the minimum Hamming distance (at the bit level) between the generated sequence and a set of target mode sequences.
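The action layout described in the list above can be sketched in plain Python. This is an illustration only: the helper names encode_action and decode_action are hypothetical and are not part of the library, and the sketch assumes nothing beyond the divmod encoding and the exit-action convention stated above.

```python
from typing import Tuple

# Illustrative sketch only: plain Python, not the library's tensor code.
word_size = 1
seq_size = 4
words_per_seq = seq_size // word_size    # 4 word positions
n_words = 2 ** word_size                 # 2 possible word values
n_actions = words_per_seq * n_words + 1  # placement actions + 1 exit action
exit_action = n_actions - 1              # the exit action is the last index


def encode_action(position: int, word: int) -> int:
    """Flatten a (position, word) pair into a single action index."""
    return position * n_words + word


def decode_action(action: int) -> Tuple[int, int]:
    """Recover (position, word) from a flat, non-exit action index."""
    return divmod(action, n_words)


# Every placement action round-trips through decode and encode.
for a in range(exit_action):
    assert encode_action(*decode_action(a)) == a

print(decode_action(5))  # (2, 1): place word 1 at position 2
```

With word_size=1 and seq_size=4 this gives the 9 actions shown in the class example further down.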

Reference:

Malkin, N., Jain, M., Bengio, E., Sun, C., & Bengio, Y. (2022). Trajectory Balance: Improved Credit Assignment in GFlowNets. https://arxiv.org/abs/2201.13259

Classes

NonAutoregressiveBitSequence

Non-autoregressive BitSequence environment.

NonAutoregressiveBitSequenceStates

States for the non-autoregressive BitSequence environment.

Module Contents

class gfn.gym.bitSequenceNonAutoregressive.NonAutoregressiveBitSequence(word_size=1, seq_size=4, n_modes=2, reward_exponent=2.0, H=None, device_str='cpu', seed=0, debug=False)

Bases: gfn.env.DiscreteEnv

Non-autoregressive BitSequence environment.

In this environment, the agent constructs a binary sequence by placing words at arbitrary positions. Each action specifies both which position to fill and which word value to place there. The episode ends when all positions are filled.

The reward is based on the minimum Hamming distance (computed at the bit level) between the completed sequence and a set of target “mode” sequences.

Parameters:
  • word_size (int) – Number of bits per word (e.g., 1 for single-bit actions).

  • seq_size (int) – Total number of bits in the sequence. Must be divisible by word_size.

  • n_modes (int) – Number of target mode sequences.

  • reward_exponent (float) – Controls reward sharpness. Higher values make the reward more peaked around the modes.

  • H (Optional[torch.Tensor]) – Optional tensor of shape (n_modes, seq_size) specifying the target modes in binary. If None, modes are generated randomly using block patterns.

  • device_str (str) – Device to use ("cpu" or "cuda").

  • seed (int) – Random seed for mode generation.

  • debug (bool) – If True, enable runtime guards (not compile-friendly).

word_size

Number of bits per word.

seq_size

Total number of bits.

words_per_seq

Number of word positions (seq_size // word_size).

n_words

Number of possible word values (2 ** word_size).

n_modes

Number of target modes.

reward_exponent

Reward sharpness parameter.

modes

Target mode sequences as a binary tensor of shape (n_modes, seq_size).

Example

>>> env = NonAutoregressiveBitSequence(word_size=1, seq_size=4, n_modes=2)
>>> # Action space: 4 positions * 2 word values + 1 exit = 9 actions
>>> env.n_actions
9
>>> # State shape: 4 word positions
>>> env.s0
tensor([-1, -1, -1, -1])
H = None
States: type[NonAutoregressiveBitSequenceStates]
_decode_action(action)

Decode a flat action index into (position, word) pair.

Parameters:

action (torch.Tensor) – Action tensor of shape (*batch_shape, 1).

Returns:

Tuple of (position, word) tensors, each of shape (*batch_shape, 1).

Return type:

Tuple[torch.Tensor, torch.Tensor]

static _integers_to_binary(tensor, k)

Convert a tensor of word integers to their binary representation.

Parameters:
  • tensor (torch.Tensor) – Integer tensor of shape (*batch_shape, words_per_seq) with values in {0, ..., 2^k - 1}.

  • k (int) – Number of bits per word.

Returns:

Binary tensor of shape (*batch_shape, words_per_seq * k) with values in {0, 1}.

Return type:

torch.Tensor
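As a rough illustration of the conversion, the sketch below expands a list of word integers into bits using plain Python lists. The function name integers_to_binary is a stand-in, and the most-significant-bit-first ordering is an assumption; the documentation above does not state which bit order the library uses.

```python
from typing import List


def integers_to_binary(words: List[int], k: int) -> List[int]:
    """Expand each word in {0, ..., 2**k - 1} into k bits.

    Assumption: bits are emitted most-significant first.
    """
    bits = []
    for w in words:
        for shift in range(k - 1, -1, -1):
            bits.append((w >> shift) & 1)
    return bits


# Three 2-bit words become six bits: 2 -> 10, 1 -> 01, 3 -> 11.
print(integers_to_binary([2, 1, 3], k=2))  # [1, 0, 0, 1, 1, 1]
```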

_make_modes(seed, device)

Generate target mode sequences in binary representation.

If H is provided, it is used directly as the modes. Otherwise, modes are constructed by randomly combining 8-bit block patterns, following the procedure from the Trajectory Balance paper.

Parameters:
  • seed (int) – Random seed.

  • device (torch.device) – Device to place the modes tensor on.

Returns:

Binary tensor of shape (n_modes, seq_size) with values in {0, 1}.

Return type:

torch.Tensor

static _min_hamming_distance(candidates, references)

Compute minimum Hamming distance from each candidate to any reference.

Parameters:
  • candidates (torch.Tensor) – Binary tensor of shape (*batch_shape, seq_size).

  • references (torch.Tensor) – Binary tensor of shape (n_refs, seq_size).

Returns:

Tensor of shape (*batch_shape,) with the minimum distance.

Return type:

torch.Tensor
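For a single candidate, the computation can be sketched in plain Python as follows. This is a minimal illustration with lists in place of tensors, not the library's batched implementation.

```python
from typing import List


def min_hamming_distance(candidate: List[int], references: List[List[int]]) -> int:
    """Return the smallest number of differing bits between the candidate
    and any one of the reference sequences."""
    return min(
        sum(c != r for c, r in zip(candidate, ref))
        for ref in references
    )


modes = [[0, 0, 0, 0], [1, 1, 1, 1]]
# Distance 3 to the all-zeros mode, distance 1 to the all-ones mode.
print(min_hamming_distance([1, 1, 0, 1], modes))  # 1
```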

backward_step(states, actions)

Undo a word placement by clearing the position back to -1.

The backward action has the same encoding as the forward action: action = position * n_words + word. The position component identifies which position to clear; the word component must match the value currently stored at that position for the action to be valid.

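Parameters:
  • states – The states to step backward from.

  • actions – The backward actions to apply, encoded as position * n_words + word.

Returns: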

Previous states with the specified positions cleared.

Return type:

NonAutoregressiveBitSequenceStates
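A minimal single-state sketch of the backward transition, using a plain Python list in place of a states object. The function below is a hypothetical stand-in rather than the library method, and raising ValueError on a mismatched word is an illustrative choice (the library expresses validity through backward masks).

```python
from typing import List

UNFILLED = -1


def backward_step(state: List[int], action: int, n_words: int) -> List[int]:
    """Clear the position named by the action, returning a new state."""
    position, word = divmod(action, n_words)
    # The action is only valid if the position holds exactly this word.
    if state[position] != word:
        raise ValueError("backward action does not match the stored word")
    previous = list(state)
    previous[position] = UNFILLED
    return previous


# Action 3 decodes to (position=1, word=1), which matches state[1].
print(backward_step([-1, 1, -1, 0], action=3, n_words=2))  # [-1, -1, -1, 0]
```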

log_reward(final_states)

Compute log-reward based on Hamming distance to nearest mode.

The log-reward is:

log R(x) = -reward_exponent * min_d(x, modes) / seq_size

where min_d is the minimum bit-level Hamming distance between the completed sequence and any target mode.

Parameters:

final_states (NonAutoregressiveBitSequenceStates) – Terminal states with all positions filled.

Returns:

Log-reward tensor of shape (*batch_shape,).

Return type:

torch.Tensor
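The formula can be checked by hand on a small case. The sketch below is plain Python over bit lists, with hypothetical modes chosen for illustration; it is not the library's tensor implementation.

```python
import math
from typing import List


def log_reward(bits: List[int], modes: List[List[int]], reward_exponent: float = 2.0) -> float:
    """log R(x) = -reward_exponent * min_d(x, modes) / seq_size."""
    seq_size = len(bits)
    min_d = min(sum(b != m for b, m in zip(bits, mode)) for mode in modes)
    return -reward_exponent * min_d / seq_size


# Hypothetical modes for a 4-bit sequence.
modes = [[0, 0, 0, 0], [1, 1, 1, 1]]

lr = log_reward([1, 1, 0, 1], modes)  # -2.0 * 1 / 4
print(lr)            # -0.5
print(math.exp(lr))  # the corresponding reward, exp(log_reward)
```

A sequence that coincides with a mode has min_d = 0, so its log-reward is 0 and its reward is 1, the maximum.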

make_random_states(batch_shape, conditions=None, device=None, debug=False)

Generate random partially-filled states.

Each position is independently either unfilled (-1) or filled with a random word value.

Parameters:
  • batch_shape (Tuple) – Shape of the batch.

  • conditions (Optional[torch.Tensor]) – Optional conditions tensor.

  • device (Optional[torch.device]) – Device to use.

  • debug (bool) – If True, enable debug mode.

Returns:

Random states.

Return type:

NonAutoregressiveBitSequenceStates
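The sampling rule can be sketched as follows with plain Python lists. The fill_prob parameter is hypothetical: the documentation does not say with what probability a position is left unfilled, so 0.5 is an assumption made for this illustration.

```python
import random
from typing import List

UNFILLED = -1


def make_random_state(
    words_per_seq: int, n_words: int, rng: random.Random, fill_prob: float = 0.5
) -> List[int]:
    """Independently fill each position with a random word, or leave it unfilled."""
    return [
        rng.randrange(n_words) if rng.random() < fill_prob else UNFILLED
        for _ in range(words_per_seq)
    ]


rng = random.Random(0)
batch = [make_random_state(4, 2, rng) for _ in range(3)]
print(batch)
```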

make_states_class()

Create the States class with environment-specific constants.

Return type:

type[NonAutoregressiveBitSequenceStates]

modes
n_modes_count = 2
property n_terminating_states: int

Total number of possible terminal states.

Return type:

int

n_words = 2
reward(final_states)

Compute reward as exp(log_reward).

Parameters:

final_states (NonAutoregressiveBitSequenceStates) – Terminal states.

Returns:

Reward tensor of shape (*batch_shape,).

Return type:

torch.Tensor

reward_exponent = 2.0
seq_size = 4
step(states, actions)

Place a word at the specified position.

The action encodes (position, word) as a flat index: action = position * n_words + word.

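Parameters:
  • states – The states to step forward from.

  • actions – The forward actions to apply, encoded as position * n_words + word.

Returns: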

Next states with the specified positions filled.

Return type:

NonAutoregressiveBitSequenceStates
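To illustrate that positions can be filled in any order, here is a single-state sketch of the forward transition on a plain Python list. The function is a hypothetical stand-in for the library method, and the ValueError is an illustrative guard (the library expresses validity through forward masks).

```python
from typing import List

UNFILLED = -1


def step(state: List[int], action: int, n_words: int) -> List[int]:
    """Place the word named by the action at its position, returning a new state."""
    position, word = divmod(action, n_words)
    if state[position] != UNFILLED:
        raise ValueError("position is already filled")
    next_state = list(state)
    next_state[position] = word
    return next_state


state = [UNFILLED] * 4
# Fill positions 3, 0, 2, 1 in that order: actions decode to
# (3, 1), (0, 0), (2, 1), (1, 0).
for action in (7, 0, 5, 2):
    state = step(state, action, n_words=2)
print(state)  # [0, 0, 1, 1]
```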

property terminating_states: NonAutoregressiveBitSequenceStates

Enumerate all terminal states (only feasible for small environments).

Return type:

NonAutoregressiveBitSequenceStates
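Because a terminal state assigns one of n_words values to each of words_per_seq positions, there are n_words ** words_per_seq of them. The sketch below enumerates them with itertools.product; the enumeration order is an assumption and may differ from the order the library returns.

```python
from itertools import product

words_per_seq = 4
n_words = 2

# Every terminal state assigns a word in {0, ..., n_words - 1} to each position.
terminal_states = [list(s) for s in product(range(n_words), repeat=words_per_seq)]

print(len(terminal_states))  # 16, i.e. n_words ** words_per_seq
print(terminal_states[:3])   # [[0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]]
```

The count grows exponentially with the sequence length, which is why enumeration is only feasible for small environments.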

true_dist(condition=None)

Compute the true reward distribution over all terminal states.

Return type:

torch.Tensor
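One plausible reading, assumed here, is that the distribution is the reward of each terminal state divided by the sum of rewards over all terminal states. The sketch below computes that quantity in plain Python for hypothetical modes and word_size=1; it illustrates the idea and is not the library's implementation.

```python
import math
from itertools import product

# Hypothetical 4-bit setup with two modes.
modes = [(0, 0, 0, 0), (1, 1, 1, 1)]
reward_exponent = 2.0
seq_size = 4


def reward(bits):
    """exp(log_reward), with log_reward as defined for this environment."""
    min_d = min(sum(b != m for b, m in zip(bits, mode)) for mode in modes)
    return math.exp(-reward_exponent * min_d / seq_size)


states = list(product((0, 1), repeat=seq_size))
rewards = [reward(s) for s in states]
total = sum(rewards)
true_dist = [r / total for r in rewards]  # normalized so the entries sum to 1

# The two modes receive the highest probability mass.
best = max(range(len(states)), key=lambda i: true_dist[i])
print(states[best], true_dist[best])
```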

word_size = 1
words_per_seq = 4
class gfn.gym.bitSequenceNonAutoregressive.NonAutoregressiveBitSequenceStates(tensor, conditions=None, device=None, debug=False)

Bases: gfn.states.DiscreteStates

States for the non-autoregressive BitSequence environment.

Each state is a tensor of shape (words_per_seq,) where each element is either -1 (unfilled) or a word value in {0, ..., n_words - 1}.

Parameters:
  • tensor (torch.Tensor)

  • conditions (Optional[torch.Tensor])

  • device (torch.device | None)

  • debug (bool)

word_size

Number of bits per word.

words_per_seq

Number of word positions in the sequence.

n_words

Number of possible word values (2 ** word_size).

_compute_backward_masks()

Compute which backward actions are valid at each state.

A backward action (pos, word) is valid iff position pos currently holds that exact word value.

Returns:

Boolean tensor of shape (*batch_shape, n_actions - 1).

Return type:

torch.Tensor
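For a single state, the rule above can be sketched with a plain Python list. The function name backward_mask is a stand-in for illustration, not the library method.

```python
from typing import List


def backward_mask(state: List[int], n_words: int) -> List[bool]:
    """Entry pos * n_words + word is True iff state[pos] holds exactly `word`."""
    return [
        state[pos] == word
        for pos in range(len(state))
        for word in range(n_words)
    ]


mask = backward_mask([-1, 1, -1, 0], n_words=2)
print(mask)
# [False, False, False, True, False, False, True, False]
```

The mask has words_per_seq * n_words entries, one fewer than n_actions, because the exit action has no backward counterpart.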

_compute_forward_masks()

Compute which forward actions are valid at each state.

An action (pos, word) is valid iff position pos is unfilled (value == -1). All n_words word choices for a given position share the same validity. The exit action is only valid when all positions are filled.

Returns:

Boolean tensor of shape (*batch_shape, n_actions).

Return type:

torch.Tensor
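The same rule for a single state, sketched with a plain Python list. The function name forward_mask is a stand-in for illustration, not the library method.

```python
from typing import List

UNFILLED = -1


def forward_mask(state: List[int], n_words: int) -> List[bool]:
    """Placement entries are True for unfilled positions; the last entry
    (the exit action) is True only when every position is filled."""
    mask = [
        state[pos] == UNFILLED
        for pos in range(len(state))
        for _ in range(n_words)  # all word choices share the position's validity
    ]
    mask.append(all(v != UNFILLED for v in state))
    return mask


print(forward_mask([-1, 1, -1, 0], n_words=2))
# [True, True, False, False, True, True, False, False, False]
print(forward_mask([0, 1, 1, 0], n_words=2))
# [False, False, False, False, False, False, False, False, True]
```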

n_words: ClassVar[int]
to_str()

Convert states to human-readable binary strings.

Returns:

List of binary strings, one per state in the flattened batch.

Return type:

List[str]

word_size: ClassVar[int]
words_per_seq: ClassVar[int]