gfn.gym.discrete_ebm
====================

.. py:module:: gfn.gym.discrete_ebm


Classes
-------

.. autoapisummary::

   gfn.gym.discrete_ebm.DiscreteEBM
   gfn.gym.discrete_ebm.EnergyFunction
   gfn.gym.discrete_ebm.IsingModel


Module Contents
---------------

.. py:class:: DiscreteEBM(ndim, energy = None, alpha = 1.0, device = 'cpu', debug = False)

   Bases: :py:obj:`gfn.env.DiscreteEnv`

   Environment for discrete energy-based models.

   This environment is based on the paper https://arxiv.org/pdf/2202.01361.pdf.

   The states are represented as 1d tensors of length `ndim` with values in
   `{-1, 0, 1}`. `s0` is empty (represented as -1), so `s0=[-1, -1, ..., -1]`.
   An action replaces a -1 with a 0 or a 1. Action `i` in `[0, ndim - 1]`
   replaces `s[i]` with 0. Action `i` in `[ndim, 2 * ndim - 1]` replaces
   `s[i - ndim]` with 1. The last action (index `2 * ndim`) is the exit action,
   which is only available for complete states (those with no -1).

   .. attribute:: ndim

      Dimension D of the sampling space `{0, 1}^D`.

      :type: int

   .. attribute:: energy

      Energy function of the EBM.

      :type: EnergyFunction

   .. attribute:: alpha

      Interaction strength of the EBM.

      :type: float


   .. py:attribute:: States
      :type: type[gfn.states.DiscreteStates]


   .. py:property:: all_states
      :type: gfn.states.DiscreteStates

      Returns all possible states of the environment.


   .. py:attribute:: alpha
      :value: 1.0


   .. py:method:: backward_step(states, actions)

      Performs a backward step.

      In this env, states are vectors of length `ndim`. `s0` is empty
      (represented as -1), so `s0=[-1, -1, ..., -1]`, and each forward action
      replaces a -1 with either a 0 or a 1. Action `i` in `[0, ndim - 1]`
      replaces `s[i]` with 0, whereas action `i` in `[ndim, 2 * ndim - 1]`
      replaces `s[i - ndim]` with 1.

      A backward action only needs to know which index to set back to -1.
      Taking the action modulo `ndim` (the `fmod`) wraps `i + ndim` back to `i`,
      so both action ranges reset the same index.

      :param states: The current states.
      :param actions: The actions to be undone.

      :returns: The previous states.


   .. py:attribute:: energy
      :type: EnergyFunction
      :value: None


   .. py:method:: get_states_indices(states)

      Given that each state is of length `ndim` with values in `{-1, 0, 1}`,
      there are `3**ndim` states, which we can label from `0` to `3**ndim - 1`.

      The easiest way to map each state to a unique integer is to read the state
      as a number in base 3, where each digit is in `{0, 1, 2}`. We therefore
      first shift each entry by 1 so that `{-1, 0, 1} -> {0, 1, 2}`.

      :param states: DiscreteStates object representing the states.

      :returns: The state indices as a tensor of shape `(*batch_shape)`.


   .. py:method:: get_terminating_states_indices(states)

      Given that each terminating state is of length `ndim` with values in
      `{0, 1}`, there are `2**ndim` terminating states, which we can label from
      `0` to `2**ndim - 1`.

      The easiest way to map each state to a unique integer is to read the state
      as a number in base 2.

      :param states: DiscreteStates object representing the states.

      :returns: The indices of the terminating states as a tensor of shape `(*batch_shape)`.


   .. py:method:: is_exit_actions(actions)

      Determines if the actions are exit actions.

      :param actions: Tensor of actions of shape `(*batch_shape, *action_shape)`.

      :returns: Tensor of booleans of shape `(*batch_shape)`.


   .. py:method:: log_partition(condition=None)

      Returns the log partition function of the reward, i.e. the log of the sum
      of rewards over all terminating states.


   .. py:method:: log_reward(final_states)

      The energy weighted by alpha is our log reward.

      :param final_states: DiscreteStates object representing the final states.

      :returns: The log reward as a tensor of shape `(*batch_shape)`.


   .. py:method:: make_random_states(batch_shape, conditions = None, device = None, debug = False)

      Generates a batch of random states whose tensor has shape `(*batch_shape, ndim)`.

      :param batch_shape: The shape of the batch.
      :param conditions: Optional tensor of shape `(*batch_shape, condition_dim)`
                         containing condition vectors for conditional GFlowNets.
      :param device: The device to use.
      :param debug: If True, emit States with debug guards (not compile-friendly).

      :returns: A `DiscreteStates` object with random states.


   .. py:method:: make_states_class()

      Returns the DiscreteStates class for the DiscreteEBM environment.


   .. py:property:: n_states
      :type: int

      Returns the number of states in the environment.


   .. py:property:: n_terminating_states
      :type: int

      Returns the number of terminating states in the environment.


   .. py:attribute:: ndim


   .. py:method:: reward(final_states)

      Computes the reward for a batch of final states.

      :param final_states: A batch of final states.

      :returns: A tensor of rewards.


   .. py:method:: step(states, actions)

      Performs a step.

      :param states: States object representing the current states.
      :param actions: Actions object representing the actions to be taken.

      :returns: The next states as a `States` object.


   .. py:property:: terminating_states
      :type: gfn.states.DiscreteStates

      Returns all terminating states of the environment.


   .. py:method:: true_dist(condition=None)

      Returns the true probability mass function of the reward distribution.


.. py:class:: EnergyFunction

   Bases: :py:obj:`torch.nn.Module`, :py:obj:`abc.ABC`

   Base class for energy functions.


   .. py:method:: forward(states)
      :abstractmethod:

      Forward pass of the energy function.

      :param states: Tensor of states of shape `(*batch_shape, *state_shape)`.

      :returns: Tensor of energies of shape `(*batch_shape)`.


.. py:class:: IsingModel(J)

   Bases: :py:obj:`EnergyFunction`

   Ising model energy function.

   .. attribute:: J

      Interaction matrix of shape `(state_shape, state_shape)`.

      :type: torch.Tensor


   .. py:attribute:: J


   .. py:method:: forward(states)

      Forward pass of the Ising model.

      :param states: Tensor of states of shape `(*batch_shape, *state_shape)`.

      :returns: Tensor of energies of shape `(*batch_shape)`.


   .. py:attribute:: linear

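
For small `ndim`, the log reward, log partition function and true distribution described for `DiscreteEBM` can be checked by brute force over all `2**ndim` terminating states. The sketch below is a minimal, standalone illustration in plain PyTorch, not the library's implementation. It assumes the common conventions E(x) = -x^T J x on spins x in `{-1, +1}^D` (obtained from complete states via x = 2s - 1) and log R(x) = -alpha * E(x); the documented `IsingModel` stores `J` and a `linear` layer, and this sketch does not claim to reproduce its exact computation. The helpers `ising_energy`, `all_terminating_states`, `log_reward` and `log_partition_and_true_dist` are hypothetical names.

.. code-block:: python

   import itertools

   import torch


   def ising_energy(spins: torch.Tensor, J: torch.Tensor) -> torch.Tensor:
       """Assumed Ising energy E(x) = -x^T J x for spins of shape (*batch, D)."""
       spins = spins.float()
       # (spins @ J) * spins summed over the last dim gives x^T J x per row.
       return -((spins @ J) * spins).sum(-1)


   def all_terminating_states(ndim: int) -> torch.Tensor:
       """All 2**ndim complete states in {0, 1}^ndim, shape (2**ndim, ndim)."""
       # itertools.product enumerates in binary order, first entry most significant.
       return torch.tensor(list(itertools.product([0, 1], repeat=ndim)), dtype=torch.long)


   def log_reward(states: torch.Tensor, J: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
       """Hypothetical log reward: map {0, 1} -> {-1, +1}, then log R = -alpha * E."""
       spins = 2 * states - 1  # assumed {0, 1} -> {-1, +1} spin mapping
       return -alpha * ising_energy(spins, J)


   def log_partition_and_true_dist(ndim: int, J: torch.Tensor, alpha: float = 1.0):
       """Brute-force log Z and the exact distribution R(x) / Z over all complete states."""
       states = all_terminating_states(ndim)
       log_r = log_reward(states, J, alpha)
       log_z = torch.logsumexp(log_r, dim=0)  # log of the sum of rewards
       return log_z, torch.exp(log_r - log_z)


   # Example: two spins coupled through J[0, 1] = 1, so aligned spins get higher reward.
   J = torch.tensor([[0.0, 1.0], [0.0, 0.0]])
   log_z, true_dist = log_partition_and_true_dist(2, J, alpha=1.0)

In this example the states `[0, 0]` and `[1, 1]` (aligned spins) each get log reward 1 and the two mixed states get -1, so log Z = log(2e + 2/e). Exhaustive enumeration is only feasible for small `ndim`, which is also the regime where comparing a trained sampler against an exact target distribution is practical.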
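
The action convention of `DiscreteEBM` (fill a -1 slot with 0 or 1, undo by index modulo `ndim`) and the base-3 / base-2 indexing described for `get_states_indices` and `get_terminating_states_indices` can be sketched on raw tensors as follows. This is a minimal illustration in plain PyTorch under stated assumptions, not the library's code: the helper names are hypothetical, the exit action is not modeled, and reading the first entry as the most significant digit is an assumption about digit order.

.. code-block:: python

   import torch


   def apply_forward_action(state: torch.Tensor, action: int, ndim: int) -> torch.Tensor:
       """Hypothetical helper: fill one empty (-1) slot with 0 or 1 (exit action not modeled)."""
       if not 0 <= action < 2 * ndim:
           raise ValueError("only fill actions in [0, 2 * ndim - 1] are modeled here")
       index = action % ndim  # actions i and i + ndim target the same slot
       value = 0 if action < ndim else 1
       if state[index] != -1:
           raise ValueError("can only fill an empty (-1) slot")
       new_state = state.clone()
       new_state[index] = value
       return new_state


   def apply_backward_action(state: torch.Tensor, action: int, ndim: int) -> torch.Tensor:
       """Hypothetical helper: reset the slot targeted by `action` back to -1."""
       # fmod wraps i + ndim back to i, so both fill ranges undo the same slot.
       index = int(torch.fmod(torch.tensor(action), ndim))
       new_state = state.clone()
       new_state[index] = -1
       return new_state


   def state_index(states: torch.Tensor) -> torch.Tensor:
       """Base-3 index of states with entries in {-1, 0, 1}; shape (*batch, ndim) -> (*batch)."""
       ndim = states.shape[-1]
       # Assumed digit order: the first entry is the most significant digit.
       weights = 3 ** torch.arange(ndim - 1, -1, -1, device=states.device)
       return ((states + 1) * weights).sum(-1)  # shift {-1, 0, 1} -> {0, 1, 2}


   def terminating_state_index(states: torch.Tensor) -> torch.Tensor:
       """Base-2 index of complete states with entries in {0, 1}."""
       ndim = states.shape[-1]
       weights = 2 ** torch.arange(ndim - 1, -1, -1, device=states.device)
       return (states * weights).sum(-1)


   ndim = 3
   s0 = torch.full((ndim,), -1, dtype=torch.long)  # initial state [-1, -1, -1]
   s1 = apply_forward_action(s0, 1, ndim)          # action 1: set s[1] = 0
   s2 = apply_forward_action(s1, 0 + ndim, ndim)   # action ndim: set s[0] = 1
   s1_again = apply_backward_action(s2, 0 + ndim, ndim)

Here `s1` is `[-1, 0, -1]` and `s2` is `[1, 0, -1]`; undoing action 3 resets slot 0 and recovers `s1`. The initial state has base-3 index 0 and the all-ones state `[1, 1, 1]` has index 26, i.e. `3**3 - 1`, matching the `0` to `3**ndim - 1` range stated above.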