gfn.gym.bitSequenceNonAutoregressive
Non-autoregressive BitSequence environment for GFlowNets.
This environment implements a non-autoregressive version of the bit sequence generation task, where actions encode both the position and word value to place. Unlike the standard (autoregressive) BitSequence environment which appends words left-to-right, this environment allows filling any unfilled position in any order.
This formulation matches the one used by the GFNX (JAX-based) library, enabling fair cross-library benchmarking.
- Environment details:
State: Tensor of shape (words_per_seq,) with values in {-1, 0, ..., 2^word_size - 1}. -1 indicates an unfilled position.
Initial state s0: All positions unfilled, [-1, -1, ..., -1].
Terminal states: All positions filled (no -1 values).
Forward actions: words_per_seq * n_words actions, where each action a encodes (position, word) = divmod(a, n_words). One additional exit action (the last action) is only available at terminal states.
Backward actions: words_per_seq * n_words actions. The backward action for a forward action (pos, word) is the same index; it clears that position back to -1.
Reward: Based on the minimum Hamming distance (at the bit level) between the generated sequence and a set of target mode sequences.
- Reference:
Malkin, N., Jain, M., Bengio, E., Sun, C., & Bengio, Y. (2022). Trajectory Balance: Improved Credit Assignment in GFlowNets. https://arxiv.org/abs/2201.13259
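The environment details above can be traced by hand on a tiny instance. The following is a minimal plain-Python sketch, not the library's implementation: states are modeled as lists rather than tensors, and `place` and `is_terminal` are hypothetical helper names introduced only for illustration.

```python
# Illustrative only: a plain-Python model of the state and action encoding.
# `place` and `is_terminal` are hypothetical helpers, not part of gfn.

def place(state, action, n_words):
    """Return a copy of `state` with the word encoded by `action` placed."""
    position, word = divmod(action, n_words)
    if state[position] != -1:
        raise ValueError(f"position {position} is already filled")
    new_state = list(state)
    new_state[position] = word
    return new_state


def is_terminal(state):
    """A state is terminal once no position holds -1."""
    return -1 not in state


word_size, seq_size = 1, 4
words_per_seq = seq_size // word_size  # 4 positions
n_words = 2 ** word_size               # 2 possible word values

state = [-1] * words_per_seq           # s0: all positions unfilled
# Fill positions 2, 0, 3, 1 in that order; any order is allowed.
for action in (2 * n_words + 1, 0 * n_words + 0, 3 * n_words + 1, 1 * n_words + 0):
    state = place(state, action, n_words)

final_state = state
print(final_state, is_terminal(final_state))
```

In the actual environment the exit action (index words_per_seq * n_words) would then be the only valid forward action.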
Classes
NonAutoregressiveBitSequence | Non-autoregressive BitSequence environment.
NonAutoregressiveBitSequenceStates | States for the non-autoregressive BitSequence environment.
Module Contents
- class gfn.gym.bitSequenceNonAutoregressive.NonAutoregressiveBitSequence(word_size=1, seq_size=4, n_modes=2, reward_exponent=2.0, H=None, device_str='cpu', seed=0, debug=False)
Bases: gfn.env.DiscreteEnv
Non-autoregressive BitSequence environment.
In this environment, the agent constructs a binary sequence by placing words at arbitrary positions. Each action specifies both which position to fill and which word value to place there. The episode ends when all positions are filled.
The reward is based on the minimum Hamming distance (computed at the bit level) between the completed sequence and a set of target “mode” sequences.
- Parameters:
word_size (int) – Number of bits per word (e.g., 1 for single-bit actions).
seq_size (int) – Total number of bits in the sequence. Must be divisible by word_size.
n_modes (int) – Number of target mode sequences.
reward_exponent (float) – Controls reward sharpness. Higher values make the reward more peaked around the modes.
H (Optional[torch.Tensor]) – Optional tensor of shape (n_modes, seq_size) specifying the target modes in binary. If None, modes are generated randomly using block patterns.
device_str (str) – Device to use ("cpu" or "cuda").
seed (int) – Random seed for mode generation.
debug (bool) – If True, enable runtime guards (not compile-friendly).
- word_size
Number of bits per word.
- seq_size
Total number of bits.
- words_per_seq
Number of word positions (seq_size // word_size).
- n_words
Number of possible word values (2 ** word_size).
- n_modes
Number of target modes.
- reward_exponent
Reward sharpness parameter.
- modes
Target mode sequences as a binary tensor of shape (n_modes, seq_size).
Example
>>> env = NonAutoregressiveBitSequence(word_size=1, seq_size=4, n_modes=2)
>>> # Action space: 4 positions * 2 word values + 1 exit = 9 actions
>>> env.n_actions
9
>>> # State shape: 4 word positions
>>> env.s0
tensor([-1, -1, -1, -1])
- H = None
- States: type[NonAutoregressiveBitSequenceStates]
- _decode_action(action)
Decode a flat action index into a (position, word) pair.
- Parameters:
action (torch.Tensor) – Action tensor of shape (*batch_shape, 1).
- Returns:
Tuple of (position, word) tensors, each of shape (*batch_shape, 1).
- Return type:
Tuple[torch.Tensor, torch.Tensor]
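The decoding is ordinary integer division with remainder. The sketch below shows the round trip on plain Python integers for word_size=2; `encode_action` and `decode_action` are hypothetical names, and the library method operates on tensors rather than ints.

```python
# Illustrative only: flat action index <-> (position, word) on plain ints.

def encode_action(position, word, n_words):
    """Flat index for placing `word` at `position`."""
    return position * n_words + word


def decode_action(action, n_words):
    """Inverse of encode_action: returns (position, word)."""
    return divmod(action, n_words)


word_size = 2
n_words = 2 ** word_size               # 4 possible word values
words_per_seq = 3
exit_action = words_per_seq * n_words  # the last index is the exit action

decoded = [decode_action(a, n_words) for a in (0, 5, 11)]
print(decoded, exit_action)
```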
- static _integers_to_binary(tensor, k)
Convert a tensor of word integers to their binary representation.
- Parameters:
tensor (torch.Tensor) – Integer tensor of shape (*batch_shape, words_per_seq) with values in {0, ..., 2^k - 1}.
k (int) – Number of bits per word.
- Returns:
Binary tensor of shape (*batch_shape, words_per_seq * k) with values in {0, 1}.
- Return type:
torch.Tensor
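As a rough illustration of the conversion, the sketch below expands each word into k bits using plain Python lists. It assumes most-significant-bit-first ordering within each word, which the text above does not state; `integers_to_binary` is a hypothetical stand-in for the tensor-based method.

```python
# Illustrative only. Assumption: bits within a word are ordered
# most significant first; the library's ordering may differ.

def integers_to_binary(words, k):
    """Expand each integer in `words` into k bits and concatenate them."""
    bits = []
    for w in words:
        bits.extend((w >> shift) & 1 for shift in range(k - 1, -1, -1))
    return bits


expanded = integers_to_binary([2, 1, 3], k=2)
print(expanded)  # 2 -> 1,0   1 -> 0,1   3 -> 1,1
```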
- _make_modes(seed, device)
Generate target mode sequences in binary representation.
If H is provided, it is used directly as the modes. Otherwise, modes are constructed by randomly combining 8-bit block patterns, following the procedure from the Trajectory Balance paper.
- Parameters:
seed (int) – Random seed.
device (torch.device) – Device to place the modes tensor on.
- Returns:
Binary tensor of shape (n_modes, seq_size) with values in {0, 1}.
- Return type:
torch.Tensor
- static _min_hamming_distance(candidates, references)
Compute the minimum Hamming distance from each candidate to any reference.
- Parameters:
candidates (torch.Tensor) – Binary tensor of shape (*batch_shape, seq_size).
references (torch.Tensor) – Binary tensor of shape (n_refs, seq_size).
- Returns:
Tensor of shape (*batch_shape,) with the minimum distance.
- Return type:
torch.Tensor
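For a single candidate, the computation reduces to counting differing bits against each reference and taking the smallest count. A minimal list-based sketch (the name `min_hamming_distance` and the unbatched signature are illustrative, not the library's):

```python
# Illustrative only: unbatched, list-based version of the distance.

def min_hamming_distance(candidate, references):
    """Smallest number of differing bits between candidate and any reference."""
    return min(
        sum(c != r for c, r in zip(candidate, reference))
        for reference in references
    )


references = [[0, 0, 0, 0], [1, 1, 1, 1]]
d_near = min_hamming_distance([0, 1, 1, 1], references)  # 1 bit from 1111
d_mid = min_hamming_distance([0, 0, 1, 1], references)   # 2 bits from both
print(d_near, d_mid)
```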
- backward_step(states, actions)
Undo a word placement by clearing the position back to -1.
The backward action has the same encoding as the forward action: action = position * n_words + word. The position component identifies which position to clear; the word component must match the value currently stored there for the action to be valid.
- Parameters:
states (NonAutoregressiveBitSequenceStates) – Current states.
actions (gfn.actions.Actions) – Backward actions to undo.
- Returns:
Previous states with the specified positions cleared.
- Return type:
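A list-based sketch of the backward step follows. `undo_placement` is a hypothetical helper, and the explicit check mirrors the validity condition documented for the backward masks (the position must hold exactly the word encoded by the action); whether the library raises in this situation is not stated here.

```python
# Illustrative only: undo one placement on a list-based state.

def undo_placement(state, action, n_words):
    """Clear the position encoded by `action` back to -1."""
    position, word = divmod(action, n_words)
    if state[position] != word:
        raise ValueError(
            f"position {position} holds {state[position]}, not {word}"
        )
    new_state = list(state)
    new_state[position] = -1
    return new_state


state = [0, -1, 1, -1]
previous = undo_placement(state, action=2 * 2 + 1, n_words=2)  # (pos=2, word=1)
print(previous)
```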
- log_reward(final_states)
Compute the log-reward based on the Hamming distance to the nearest mode.
- The log-reward is:
log R(x) = -reward_exponent * min_d(x, modes) / seq_size
where min_d is the minimum bit-level Hamming distance between the completed sequence and any target mode.
- Parameters:
final_states (NonAutoregressiveBitSequenceStates) – Terminal states with all positions filled.
- Returns:
Log-reward tensor of shape (*batch_shape,).
- Return type:
torch.Tensor
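The formula can be checked with a small worked example. The sketch below evaluates it for one 4-bit sequence against two hand-picked modes; the function is a plain-Python illustration of the stated formula, not the library's tensor implementation, and the mode values are chosen only for the example.

```python
import math

# Illustrative only: the log-reward formula on plain Python lists.

def log_reward(bits, modes, reward_exponent):
    """-reward_exponent * (min Hamming distance to any mode) / seq_size."""
    seq_size = len(bits)
    min_d = min(sum(b != m for b, m in zip(bits, mode)) for mode in modes)
    return -reward_exponent * min_d / seq_size


modes = [[0, 0, 0, 0], [1, 1, 1, 1]]   # example modes, chosen by hand
value = log_reward([0, 1, 1, 1], modes, reward_exponent=2.0)
# min_d = 1, seq_size = 4, so log R = -2.0 * 1 / 4 = -0.5
print(value, math.exp(value))
```

A sequence that coincides with a mode has min_d = 0, so its log-reward is 0 and its reward is 1, the maximum.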
- make_random_states(batch_shape, conditions=None, device=None, debug=False)
Generate random partially-filled states.
Each position is independently either unfilled (-1) or filled with a random word value.
- Parameters:
batch_shape (Tuple) – Shape of the batch.
conditions (Optional[torch.Tensor]) – Optional conditions tensor.
device (Optional[torch.device]) – Device to use.
debug (bool) – If True, enable debug mode.
- Returns:
Random states.
- Return type:
- make_states_class()
Create the States class with environment-specific constants.
- Return type:
- modes
- n_modes_count = 2
- property n_terminating_states: int
Total number of possible terminal states.
- Return type:
int
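Since a terminal state assigns one of n_words values to each of the words_per_seq positions, the count follows directly as n_words ** words_per_seq, which equals 2 ** seq_size. The sketch below confirms this by brute-force enumeration; `enumerate_terminal_states` is a hypothetical helper and its ordering need not match the library's.

```python
from itertools import product

# Illustrative only: brute-force enumeration of fully filled states.

def enumerate_terminal_states(words_per_seq, n_words):
    """All assignments of a word value to every position."""
    return [list(s) for s in product(range(n_words), repeat=words_per_seq)]


# word_size=1, seq_size=4 -> 4 positions, 2 word values
bitwise = enumerate_terminal_states(words_per_seq=4, n_words=2)
# word_size=2, seq_size=4 -> 2 positions, 4 word values
wordwise = enumerate_terminal_states(words_per_seq=2, n_words=4)
print(len(bitwise), len(wordwise))
```

Both configurations describe the same 16 bit strings, which is why the count depends only on seq_size.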
- n_words = 2
- reward(final_states)
Compute reward as exp(log_reward).
- Parameters:
final_states (NonAutoregressiveBitSequenceStates) – Terminal states.
- Returns:
Reward tensor of shape (*batch_shape,).
- Return type:
torch.Tensor
- reward_exponent = 2.0
- seq_size = 4
- step(states, actions)
Place a word at the specified position.
The action encodes (position, word) as a flat index: action = position * n_words + word.
- Parameters:
states (NonAutoregressiveBitSequenceStates) – Current states.
actions (gfn.actions.Actions) – Actions encoding (position, word) pairs.
- Returns:
Next states with the specified positions filled.
- Return type:
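Because positions may be filled in any order, every terminal state is reachable by words_per_seq! distinct action orderings. The sketch below verifies this for a 4-position example using plain lists; `apply_actions` is a hypothetical helper that mimics repeated forward steps.

```python
from itertools import permutations

# Illustrative only: replay forward placements on a list-based state.

def apply_actions(s0, actions, n_words):
    """Apply placement actions in order, starting from s0."""
    state = list(s0)
    for action in actions:
        position, word = divmod(action, n_words)
        if state[position] != -1:
            raise ValueError(f"position {position} is already filled")
        state[position] = word
    return state


n_words = 2
target = [1, 0, 1, 1]
# One placement action per position of the target sequence.
actions = [pos * n_words + word for pos, word in enumerate(target)]

orders = list(permutations(actions))
results = {tuple(apply_actions([-1] * 4, order, n_words)) for order in orders}
n_orders = len(orders)
print(n_orders, results)
```

All 24 orderings end in the same state, which is the property that distinguishes this environment from the autoregressive one.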
- property terminating_states: NonAutoregressiveBitSequenceStates
Enumerate all terminal states (only feasible for small environments).
- Return type:
NonAutoregressiveBitSequenceStates
- true_dist(condition=None)
Compute the true reward distribution over all terminal states.
- Return type:
torch.Tensor
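One plausible reading, sketched below, is that the true distribution assigns each terminal state a probability proportional to its reward, p(x) = R(x) / sum of R over all terminal states. This is an assumption about what the method computes, and the enumeration order used here need not match the library's. The example uses word_size=1 and hand-picked modes.

```python
import math
from itertools import product

# Illustrative only. Assumption: the target distribution is the reward
# normalized over all terminal states.

def min_hamming(bits, modes):
    return min(sum(b != m for b, m in zip(bits, mode)) for mode in modes)


def true_distribution(modes, reward_exponent):
    """Return (states, probabilities) for every bit string of mode length."""
    seq_size = len(modes[0])
    states = [list(bits) for bits in product((0, 1), repeat=seq_size)]
    rewards = [
        math.exp(-reward_exponent * min_hamming(s, modes) / seq_size)
        for s in states
    ]
    total = sum(rewards)
    return states, [r / total for r in rewards]


modes = [[0, 0, 0, 0], [1, 1, 1, 1]]   # example modes, chosen by hand
states, probs = true_distribution(modes, reward_exponent=2.0)
print(max(probs), min(probs))
```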
- word_size = 1
- words_per_seq = 4
- class gfn.gym.bitSequenceNonAutoregressive.NonAutoregressiveBitSequenceStates(tensor, conditions=None, device=None, debug=False)
Bases: gfn.states.DiscreteStates
States for the non-autoregressive BitSequence environment.
Each state is a tensor of shape (words_per_seq,) where each element is either -1 (unfilled) or a word value in {0, ..., n_words - 1}.
- Parameters:
tensor (torch.Tensor)
conditions (Optional[torch.Tensor])
device (torch.device | None)
debug (bool)
- word_size
Number of bits per word.
- words_per_seq
Number of word positions in the sequence.
- n_words
Number of possible word values (2 ** word_size).
- _compute_backward_masks()
Compute which backward actions are valid at each state.
A backward action (pos, word) is valid iff position pos currently holds that exact word value.
- Returns:
Boolean tensor of shape (*batch_shape, n_actions - 1).
- Return type:
torch.Tensor
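The rule above can be sketched for a single list-based state as follows; `backward_mask` is a hypothetical helper and the real method works on batched tensors.

```python
# Illustrative only: backward mask for one list-based state.

def backward_mask(state, n_words):
    """True at index pos * n_words + word iff state[pos] == word."""
    mask = [False] * (len(state) * n_words)
    for position, value in enumerate(state):
        if value != -1:
            mask[position * n_words + value] = True
    return mask


mask = backward_mask([0, -1, 1, -1], n_words=2)
print(mask)  # only (pos=0, word=0) and (pos=2, word=1) can be undone
```

The mask has n_actions - 1 entries because the exit action has no backward counterpart.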
- _compute_forward_masks()
Compute which forward actions are valid at each state.
An action (pos, word) is valid iff position pos is unfilled (value == -1). All n_words word choices for a given position share the same validity. The exit action is only valid when all positions are filled.
- Returns:
Boolean tensor of shape (*batch_shape, n_actions).
- Return type:
torch.Tensor
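A list-based sketch of the same rule for a single state is shown below; `forward_mask` is a hypothetical helper, with the exit action stored in the last slot as described in the module overview.

```python
# Illustrative only: forward mask for one list-based state.

def forward_mask(state, n_words):
    """Placement actions are valid for unfilled positions; exit is last."""
    mask = []
    for value in state:
        # All n_words choices for a position share the same validity.
        mask.extend([value == -1] * n_words)
    mask.append(all(value != -1 for value in state))  # exit action
    return mask


partial = forward_mask([0, -1, 1, -1], n_words=2)
terminal = forward_mask([0, 1, 1, 0], n_words=2)
print(partial)
print(terminal)
```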
- n_words: ClassVar[int]
- to_str()
Convert states to human-readable binary strings.
- Returns:
List of binary strings, one per state in the flattened batch.
- Return type:
List[str]
- word_size: ClassVar[int]
- words_per_seq: ClassVar[int]