gfn.gym.bitSequenceNonAutoregressive
====================================

.. py:module:: gfn.gym.bitSequenceNonAutoregressive

.. autoapi-nested-parse::

   Non-autoregressive BitSequence environment for GFlowNets.

   This environment implements a non-autoregressive version of the bit sequence
   generation task, where actions encode both the position and the word value
   to place. Unlike the standard (autoregressive) BitSequence environment, which
   appends words left-to-right, this environment allows filling any unfilled
   position in any order.

   This formulation matches the one used by the GFNX (JAX-based) library,
   enabling fair cross-library benchmarking.

   Environment details:

   - State: Tensor of shape ``(words_per_seq,)`` with values in
     ``{-1, 0, ..., 2^word_size - 1}``. ``-1`` indicates an unfilled position.
   - Initial state ``s0``: All positions unfilled, ``[-1, -1, ..., -1]``.
   - Terminal states: All positions filled (no ``-1`` values).
   - Forward actions: ``words_per_seq * n_words`` actions, where each action
     ``a`` encodes ``(position, word) = divmod(a, n_words)``. One additional
     exit action (the last action) is only available at terminal states.
   - Backward actions: ``words_per_seq * n_words`` actions. The backward action
     for a forward action ``(pos, word)`` has the same index; it clears that
     position back to ``-1``.
   - Reward: Based on the minimum Hamming distance (at the bit level) between
     the generated sequence and a set of target mode sequences.

   Reference:
      Malkin, N., Jain, M., Bengio, E., Sun, C., & Bengio, Y. (2022).
      Trajectory Balance: Improved Credit Assignment in GFlowNets.
      https://arxiv.org/abs/2201.13259


Classes
-------

.. autoapisummary::

   gfn.gym.bitSequenceNonAutoregressive.NonAutoregressiveBitSequence
   gfn.gym.bitSequenceNonAutoregressive.NonAutoregressiveBitSequenceStates


Module Contents
---------------
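
The flat action encoding described in the module overview can be checked with a
small plain-Python sketch. The helper names ``action_space``, ``encode_action``
and ``decode_action`` are hypothetical illustrations, not part of ``gfn.gym``;
they only restate the arithmetic ``(position, word) = divmod(a, n_words)`` and
the trailing exit action.

.. code-block:: python

   def action_space(word_size: int, seq_size: int) -> tuple[int, int, int]:
       """Return (words_per_seq, n_words, n_actions) for the given sizes."""
       if seq_size % word_size != 0:
           raise ValueError("seq_size must be divisible by word_size")
       words_per_seq = seq_size // word_size
       n_words = 2 ** word_size
       # One action per (position, word) pair, plus one exit action at the end.
       return words_per_seq, n_words, words_per_seq * n_words + 1


   def encode_action(position: int, word: int, n_words: int) -> int:
       """Flat index of the action that places `word` at `position`."""
       return position * n_words + word


   def decode_action(action: int, n_words: int) -> tuple[int, int]:
       """Inverse of encode_action; not meaningful for the exit action."""
       return divmod(action, n_words)


   words_per_seq, n_words, n_actions = action_space(word_size=1, seq_size=4)
   print(n_actions)                  # 9, matching the class example
   print(decode_action(5, n_words))  # (2, 1): place word 1 at position 2
   print(n_actions - 1)              # 8: index of the exit action

Each position owns a contiguous block of ``n_words`` indices, from
``position * n_words`` to ``position * n_words + n_words - 1``, so a larger
``word_size`` widens the blocks without changing their order.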

.. py:class:: NonAutoregressiveBitSequence(word_size = 1, seq_size = 4, n_modes = 2, reward_exponent = 2.0, H = None, device_str = 'cpu', seed = 0, debug = False)

   Bases: :py:obj:`gfn.env.DiscreteEnv`

   Non-autoregressive BitSequence environment.

   In this environment, the agent constructs a binary sequence by placing words
   at arbitrary positions. Each action specifies both which position to fill and
   which word value to place there. The episode ends when all positions are
   filled. The reward is based on the minimum Hamming distance (computed at the
   bit level) between the completed sequence and a set of target "mode"
   sequences.

   :param word_size: Number of bits per word (e.g., 1 for single-bit actions).
   :param seq_size: Total number of bits in the sequence. Must be divisible by ``word_size``.
   :param n_modes: Number of target mode sequences.
   :param reward_exponent: Controls reward sharpness. Higher values make the reward more peaked around the modes.
   :param H: Optional tensor of shape ``(n_modes, seq_size)`` specifying the target modes in binary. If None, modes are generated randomly using block patterns.
   :param device_str: Device to use (``"cpu"`` or ``"cuda"``).
   :param seed: Random seed for mode generation.
   :param debug: If True, enable runtime guards (not compile-friendly).

   .. attribute:: word_size

      Number of bits per word.

   .. attribute:: seq_size

      Total number of bits.

   .. attribute:: words_per_seq

      Number of word positions (``seq_size // word_size``).

   .. attribute:: n_words

      Number of possible word values (``2 ** word_size``).

   .. attribute:: n_modes

      Number of target modes.

   .. attribute:: reward_exponent

      Reward sharpness parameter.

   .. attribute:: modes

      Target mode sequences as a binary tensor of shape ``(n_modes, seq_size)``.
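
   The bit-level reward described above can be illustrated with a plain-Python
   sketch. It is not the library's tensor implementation: the helper names are
   hypothetical, and the most-significant-bit-first order inside each word is an
   assumption (the bit order used by ``_integers_to_binary`` is not stated here).

   .. code-block:: python

      import math


      def integers_to_binary(words: list[int], k: int) -> list[int]:
          """Expand each k-bit word into k bits (MSB first: an assumption)."""
          return [(w >> (k - 1 - i)) & 1 for w in words for i in range(k)]


      def min_hamming_distance(bits: list[int], modes: list[list[int]]) -> int:
          """Fewest differing bits between `bits` and any mode."""
          return min(sum(a != b for a, b in zip(bits, mode)) for mode in modes)


      def log_reward(words, modes, word_size, seq_size, reward_exponent):
          """log R(x) = -reward_exponent * min_d(x, modes) / seq_size."""
          if any(w == -1 for w in words):
              raise ValueError("reward is only defined for terminal states")
          bits = integers_to_binary(words, word_size)
          return -reward_exponent * min_hamming_distance(bits, modes) / seq_size


      # word_size=2, seq_size=4: the state [2, 3] expands to bits [1, 0, 1, 1].
      modes = [[1, 0, 1, 0], [0, 0, 0, 0]]
      lr = log_reward([2, 3], modes, word_size=2, seq_size=4, reward_exponent=2.0)
      print(lr)            # -0.5: the nearest mode differs in one bit
      print(math.exp(lr))  # reward R(x) = exp(log R(x))

   Because ``min_d`` lies in ``[0, seq_size]``, dividing by ``seq_size`` keeps
   the log-reward in ``[-reward_exponent, 0]``: rewards lie in
   ``[exp(-reward_exponent), 1]``, with reward 1 exactly at the modes.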

   .. rubric:: Example

   >>> env = NonAutoregressiveBitSequence(word_size=1, seq_size=4, n_modes=2)
   >>> # Action space: 4 positions * 2 word values + 1 exit = 9 actions
   >>> env.n_actions
   9
   >>> # State shape: 4 word positions
   >>> env.s0
   tensor([-1, -1, -1, -1])


   .. py:attribute:: H
      :value: None


   .. py:attribute:: States
      :type: type[NonAutoregressiveBitSequenceStates]


   .. py:method:: _decode_action(action)

      Decode a flat action index into a ``(position, word)`` pair.

      :param action: Action tensor of shape ``(*batch_shape, 1)``.

      :returns: Tuple of ``(position, word)`` tensors, each of shape ``(*batch_shape, 1)``.


   .. py:method:: _integers_to_binary(tensor, k)
      :staticmethod:

      Convert a tensor of word integers to their binary representation.

      :param tensor: Integer tensor of shape ``(*batch_shape, words_per_seq)`` with values in ``{0, ..., 2^k - 1}``.
      :param k: Number of bits per word.

      :returns: Binary tensor of shape ``(*batch_shape, words_per_seq * k)`` with values in ``{0, 1}``.


   .. py:method:: _make_modes(seed, device)

      Generate target mode sequences in binary representation.

      If ``H`` is provided, it is used directly as the modes. Otherwise, modes
      are constructed by randomly combining 8-bit block patterns, following the
      procedure from the Trajectory Balance paper.

      :param seed: Random seed.
      :param device: Device to place the modes tensor on.

      :returns: Binary tensor of shape ``(n_modes, seq_size)`` with values in ``{0, 1}``.


   .. py:method:: _min_hamming_distance(candidates, references)
      :staticmethod:

      Compute the minimum Hamming distance from each candidate to any reference.

      :param candidates: Binary tensor of shape ``(*batch_shape, seq_size)``.
      :param references: Binary tensor of shape ``(n_refs, seq_size)``.

      :returns: Tensor of shape ``(*batch_shape,)`` with the minimum distance.


   .. py:method:: backward_step(states, actions)

      Undo a word placement by clearing the position back to ``-1``.

      The backward action has the same encoding as the forward action:
      ``action = position * n_words + word``.
      The position component selects the slot to clear; the word component must
      match the value currently stored at that position, as enforced by the
      backward masks.

      :param states: Current states.
      :param actions: Backward actions to undo.

      :returns: Previous states with the specified positions cleared.


   .. py:method:: log_reward(final_states)

      Compute the log-reward based on the Hamming distance to the nearest mode.

      The log-reward is:

      ``log R(x) = -reward_exponent * min_d(x, modes) / seq_size``

      where ``min_d`` is the minimum bit-level Hamming distance between the
      completed sequence and any target mode.

      :param final_states: Terminal states with all positions filled.

      :returns: Log-reward tensor of shape ``(*batch_shape,)``.


   .. py:method:: make_random_states(batch_shape, conditions = None, device = None, debug = False)

      Generate random partially filled states.

      Each position is independently either unfilled (``-1``) or filled with a
      random word value.

      :param batch_shape: Shape of the batch.
      :param conditions: Optional conditions tensor.
      :param device: Device to use.
      :param debug: If True, enable debug mode.

      :returns: Random states.


   .. py:method:: make_states_class()

      Create the States class with environment-specific constants.


   .. py:attribute:: modes


   .. py:attribute:: n_modes_count
      :value: 2


   .. py:property:: n_terminating_states
      :type: int

      Total number of possible terminal states, ``n_words ** words_per_seq``
      (equivalently ``2 ** seq_size``).


   .. py:attribute:: n_words
      :value: 2


   .. py:method:: reward(final_states)

      Compute the reward as ``exp(log_reward)``.

      :param final_states: Terminal states.

      :returns: Reward tensor of shape ``(*batch_shape,)``.


   .. py:attribute:: reward_exponent
      :value: 2.0


   .. py:attribute:: seq_size
      :value: 4


   .. py:method:: step(states, actions)

      Place a word at the specified position.

      The action encodes ``(position, word)`` as a flat index:
      ``action = position * n_words + word``.

      :param states: Current states.
      :param actions: Actions encoding ``(position, word)`` pairs.

      :returns: Next states with the specified positions filled.
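
      The forward and backward transitions can be sketched on a plain Python
      list. This is an illustration, not the library's batched implementation:
      the function names are hypothetical, the exit action is not handled, and
      the ``ValueError`` checks stand in for the validity rules that the forward
      and backward masks encode.

      .. code-block:: python

         UNFILLED = -1


         def step(state: list[int], action: int, n_words: int) -> list[int]:
             """Place the decoded word at the decoded (unfilled) position."""
             position, word = divmod(action, n_words)
             if state[position] != UNFILLED:
                 raise ValueError(f"position {position} is already filled")
             next_state = list(state)  # copy so the input state is left unchanged
             next_state[position] = word
             return next_state


         def backward_step(state: list[int], action: int, n_words: int) -> list[int]:
             """Clear the decoded position; it must hold the decoded word."""
             position, word = divmod(action, n_words)
             if state[position] != word:
                 raise ValueError(f"position {position} does not hold word {word}")
             prev_state = list(state)
             prev_state[position] = UNFILLED
             return prev_state


         s0 = [UNFILLED] * 4          # word_size=1, seq_size=4
         s1 = step(s0, 5, n_words=2)  # action 5 -> (position 2, word 1)
         print(s1)                    # [-1, -1, 1, -1]
         print(backward_step(s1, 5, n_words=2) == s0)  # True

      Because a backward action reuses the forward index, undoing a placement
      only requires passing the same action index to ``backward_step``.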

   .. py:property:: terminating_states
      :type: NonAutoregressiveBitSequenceStates

      Enumerate all terminal states (feasible only for small environments).


   .. py:method:: true_dist(condition=None)

      Compute the true (reward-proportional) target distribution over all terminal states.


   .. py:attribute:: word_size
      :value: 1


   .. py:attribute:: words_per_seq
      :value: 4


.. py:class:: NonAutoregressiveBitSequenceStates(tensor, conditions = None, device = None, debug = False)

   Bases: :py:obj:`gfn.states.DiscreteStates`

   States for the non-autoregressive BitSequence environment.

   Each state is a tensor of shape ``(words_per_seq,)`` where each element is
   either ``-1`` (unfilled) or a word value in ``{0, ..., n_words - 1}``.

   .. attribute:: word_size

      Number of bits per word.

   .. attribute:: words_per_seq

      Number of word positions in the sequence.

   .. attribute:: n_words

      Number of possible word values (``2 ** word_size``).


   .. py:method:: _compute_backward_masks()

      Compute which backward actions are valid at each state.

      A backward action ``(pos, word)`` is valid iff position ``pos`` currently
      holds that exact word value.

      :returns: Boolean tensor of shape ``(*batch_shape, n_actions - 1)``.


   .. py:method:: _compute_forward_masks()

      Compute which forward actions are valid at each state.

      An action ``(pos, word)`` is valid iff position ``pos`` is unfilled
      (value ``== -1``). All ``n_words`` word choices for a given position share
      the same validity. The exit action is valid only when all positions are
      filled.

      :returns: Boolean tensor of shape ``(*batch_shape, n_actions)``.


   .. py:attribute:: n_words
      :type: ClassVar[int]


   .. py:method:: to_str()

      Convert states to human-readable binary strings.

      :returns: List of binary strings, one per state in the flattened batch.


   .. py:attribute:: word_size
      :type: ClassVar[int]


   .. py:attribute:: words_per_seq
      :type: ClassVar[int]
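
   The two mask rules above can be written out for a single state as lists of
   booleans. This plain-Python sketch uses hypothetical function names and
   assumes the flat action order ``position * n_words + word`` with the exit
   action last; the library applies the same rules to batched tensors.

   .. code-block:: python

      UNFILLED = -1


      def forward_mask(state: list[int], n_words: int) -> list[bool]:
          """words_per_seq * n_words placement entries, then the exit action."""
          # Index position * n_words + word: outer loop over positions, inner over words.
          placements = [value == UNFILLED for value in state for _ in range(n_words)]
          can_exit = all(value != UNFILLED for value in state)
          return placements + [can_exit]


      def backward_mask(state: list[int], n_words: int) -> list[bool]:
          """words_per_seq * n_words entries; there is no backward exit action."""
          return [value == word for value in state for word in range(n_words)]


      state = [1, UNFILLED]  # word_size=1: position 0 holds 1, position 1 is empty
      print(forward_mask(state, n_words=2))   # [False, False, True, True, False]
      print(backward_mask(state, n_words=2))  # [False, True, False, False]

   At a filled position exactly one backward action is valid (the one matching
   the stored word), whereas at an unfilled position every word choice is valid
   in the forward direction.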