gfn.gym.discrete_ebm¶
Classes¶
DiscreteEBM | Environment for discrete energy-based models.
EnergyFunction | Base class for energy functions.
IsingModel | Ising model energy function.
Module Contents¶
- class gfn.gym.discrete_ebm.DiscreteEBM(ndim, energy=None, alpha=1.0, device='cpu', debug=False)¶
Bases: gfn.env.DiscreteEnv
Environment for discrete energy-based models.
This environment is based on the paper https://arxiv.org/pdf/2202.01361.pdf.
The states are represented as 1d tensors of length ndim with values in {-1, 0, 1}. s0 is empty (represented as -1), so s0=[-1, -1, …, -1]. An action corresponds to replacing a -1 with a 0 or a 1. Action i in [0, ndim - 1] corresponds to replacing s[i] with 0. Action i in [ndim, 2 * ndim - 1] corresponds to replacing s[i - ndim] with 1. The last action is the exit action that is only available for complete states (those with no -1).
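The action encoding described above can be sketched in plain Python. This is only an illustration, not the library's implementation: lists stand in for the 1d tensors, and `decode_action` and `apply_action` are hypothetical helper names that do not exist in `gfn`.

```python
EMPTY = -1  # marker for a coordinate that has not been set yet


def decode_action(action, ndim):
    """Return (kind, position, value) for an action index in [0, 2 * ndim]."""
    if not 0 <= action <= 2 * ndim:
        raise ValueError(f"action {action} is out of range for ndim={ndim}")
    if action == 2 * ndim:
        return ("exit", None, None)  # the last action is the exit action
    if action < ndim:
        return ("set", action, 0)  # actions 0 .. ndim - 1 write a 0
    return ("set", action - ndim, 1)  # actions ndim .. 2 * ndim - 1 write a 1


def apply_action(state, action):
    """Apply one forward action to a state given as a list of -1/0/1."""
    kind, position, value = decode_action(action, len(state))
    if kind == "exit":
        if EMPTY in state:
            raise ValueError("exit is only available for complete states")
        return list(state)
    if state[position] != EMPTY:
        raise ValueError(f"position {position} is already set")
    next_state = list(state)
    next_state[position] = value
    return next_state


s0 = [EMPTY] * 3          # ndim = 3, so s0 = [-1, -1, -1]
s1 = apply_action(s0, 1)  # action 1: s[1] <- 0
s2 = apply_action(s1, 3)  # action 3: s[3 - 3] = s[0] <- 1
s3 = apply_action(s2, 5)  # action 5: s[5 - 3] = s[2] <- 1
print(s1, s2, s3)         # [-1, 0, -1] [1, 0, -1] [1, 0, 1]
```

With ndim = 3 there are 2 * 3 + 1 = 7 actions in total, and index 6 is the exit action.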
- Parameters:
ndim (int)
energy (EnergyFunction | None)
alpha (float)
device (Literal['cpu', 'cuda'] | torch.device)
debug (bool)
- ndim¶
Dimension D of the sampling space {0, 1}^D.
- Type:
int
- energy¶
Energy function of the EBM.
- Type:
EnergyFunction
- alpha¶
Interaction strength of the EBM.
- Type:
float
- States: type[gfn.states.DiscreteStates]¶
- property all_states: gfn.states.DiscreteStates¶
Returns all possible states of the environment.
- Return type:
gfn.states.DiscreteStates
- alpha = 1.0¶
- backward_step(states, actions)¶
Performs a backward step.
In this environment, states are ndim-dimensional vectors. s0 is empty (represented as -1), so s0=[-1, -1, …, -1], and each action replaces a -1 with either a 0 or a 1. Action i in [0, ndim - 1] replaces s[i] with 0, whereas action i in [ndim, 2 * ndim - 1] replaces s[i - ndim] with 1. A backward action asks “which index should be set back to -1”, hence the fmod, which wraps both ranges of action indices onto positions in [0, ndim - 1].
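A minimal sketch of this index wrapping, assuming plain Python lists in place of tensors. `undo_action` is a hypothetical name used only for illustration, and `%` plays the role of fmod for the non-negative action indices used here.

```python
EMPTY = -1  # marker for a coordinate that has not been set


def undo_action(state, action):
    """Set the coordinate targeted by a non-exit action back to -1."""
    ndim = len(state)
    if not 0 <= action < 2 * ndim:
        raise ValueError(f"action {action} is not a valid backward action")
    position = action % ndim  # actions i and i + ndim both target s[i]
    if state[position] == EMPTY:
        raise ValueError(f"position {position} is already empty")
    previous = list(state)
    previous[position] = EMPTY
    return previous


state = [1, 0, -1]
print(undo_action(state, 3))  # action 3 targets s[0]: [-1, 0, -1]
print(undo_action(state, 1))  # action 1 targets s[1]: [1, -1, -1]
```

This sketch does not check that the undone action matches the value currently stored at the position; whether the library enforces that is not stated here.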
- Parameters:
states (gfn.states.States) – The current states.
actions (gfn.actions.Actions) – The actions to be undone.
- Returns:
The previous states.
- Return type:
- energy: EnergyFunction = None¶
- get_states_indices(states)¶
Given that each state is of length ndim with values in {-1, 0, 1}, there are 3**ndim states, which we can label from 0 to 3**ndim - 1.
The easiest way to map each state to a unique integer is to read the state as a number in base 3, where each digit is in {0, 1, 2}. Each entry is therefore shifted by 1 first, so that {-1, 0, 1} -> {0, 1, 2}.
- Parameters:
states (gfn.states.DiscreteStates) – DiscreteStates object representing the states.
- Returns:
The states indices as tensor of shape (*batch_shape).
- Return type:
torch.Tensor
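The base-3 labelling can be sketched as follows for a single state held in a Python list. `state_index` is a hypothetical helper, and treating the first coordinate as the most significant digit is an assumption; the documentation above does not fix the digit order.

```python
def state_index(state):
    """Read a state with entries in {-1, 0, 1} as a base-3 number."""
    index = 0
    for value in state:
        digit = value + 1          # shift {-1, 0, 1} to {0, 1, 2}
        index = index * 3 + digit  # first coordinate = most significant digit
    return index


print(state_index([-1, -1, -1]))  # 0, the initial state s0
print(state_index([1, 0, -1]))    # 2 * 9 + 1 * 3 + 0 = 21
print(state_index([1, 1, 1]))     # 3 ** 3 - 1 = 26
```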
- get_terminating_states_indices(states)¶
Given that each terminating state is of length ndim with values in {0, 1}, there are 2**ndim terminating states, which we can label from 0 to 2**ndim - 1.
The easiest way to map each state to a unique integer is to consider the state as a number in base 2.
- Parameters:
states (gfn.states.DiscreteStates) – DiscreteStates object representing the states.
- Returns:
The indices of the terminating states as tensor of shape (*batch_shape).
- Return type:
torch.Tensor
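The same idea in base 2, again as a sketch on a plain Python list. `terminating_state_index` is a hypothetical name, and the most-significant-first digit order is an assumption.

```python
def terminating_state_index(state):
    """Read a complete state with entries in {0, 1} as a base-2 number."""
    index = 0
    for bit in state:
        if bit not in (0, 1):
            raise ValueError("terminating states contain only 0 and 1")
        index = index * 2 + bit  # first coordinate = most significant bit
    return index


print(terminating_state_index([0, 0, 0]))  # 0
print(terminating_state_index([1, 0, 1]))  # 4 + 0 + 1 = 5
print(terminating_state_index([1, 1, 1]))  # 2 ** 3 - 1 = 7
```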
- is_exit_actions(actions)¶
Determines if the actions are exit actions.
- Parameters:
actions (torch.Tensor) – tensor of actions of shape (*batch_shape, *action_shape).
- Returns:
Tensor of booleans of shape (*batch_shape).
- Return type:
torch.Tensor
- log_partition(condition=None)¶
Returns the log partition function of the reward, i.e. the log of the sum of rewards over all terminating states.
- Return type:
float
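For small ndim, the log partition function can be computed by brute force over all 2**ndim terminating states. The sketch below is one way to do that, not necessarily how the library does it: `log_partition_bruteforce` and `toy_log_reward` are hypothetical names, and the toy log reward is made up purely so that the result can be checked by hand.

```python
import itertools
import math


def log_partition_bruteforce(log_reward, ndim):
    """log Z = log of the sum of exp(log_reward(s)) over all s in {0, 1}^ndim."""
    log_rewards = [log_reward(s) for s in itertools.product((0, 1), repeat=ndim)]
    largest = max(log_rewards)  # subtract the max for numerical stability
    return largest + math.log(sum(math.exp(lr - largest) for lr in log_rewards))


def toy_log_reward(state):
    """Made-up log reward: the number of ones in the state."""
    return float(sum(state))


# For ndim = 2 the rewards are 1, e, e, e**2, so Z = (1 + e)**2.
log_z = log_partition_bruteforce(toy_log_reward, 2)
print(log_z, 2 * math.log(1 + math.e))
```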
- log_reward(final_states)¶
The energy weighted by alpha is our log reward.
- Parameters:
final_states (gfn.states.DiscreteStates) – DiscreteStates object representing the final states.
- Returns:
The log reward as tensor of shape (*batch_shape).
- Return type:
torch.Tensor
- make_random_states(batch_shape, conditions=None, device=None, debug=False)¶
Generates a random states tensor of shape (*batch_shape, ndim).
- Parameters:
batch_shape (Tuple) – The shape of the batch.
conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.
device (torch.device | None) – The device to use.
debug (bool) – If True, emit States with debug guards (not compile-friendly).
- Returns:
A DiscreteStates object with random states.
- Return type:
gfn.states.DiscreteStates
- make_states_class()¶
Returns the DiscreteStates class for the DiscreteEBM environment.
- Return type:
type[gfn.states.DiscreteStates]
- property n_states: int¶
Returns the number of states in the environment.
- Return type:
int
- property n_terminating_states: int¶
Returns the number of terminating states in the environment.
- Return type:
int
- reward(final_states)¶
Computes the reward for a batch of final states.
- Parameters:
final_states (gfn.states.DiscreteStates) – A batch of final states.
- Returns:
A tensor of rewards.
- Return type:
torch.Tensor
- step(states, actions)¶
Performs a step.
- Parameters:
states (gfn.states.States) – States object representing the current states.
actions (gfn.actions.Actions) – Actions object representing the actions to be taken.
- Returns:
The next states as a States object.
- Return type:
- property terminating_states: gfn.states.DiscreteStates¶
Returns all terminating states of the environment.
- Return type:
gfn.states.DiscreteStates
- true_dist(condition=None)¶
Returns the true probability mass function of the reward distribution.
- Return type:
torch.Tensor
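The target distribution assigns each terminating state a probability proportional to its reward. A brute-force sketch under stated assumptions: `true_distribution` and `toy_log_reward` are hypothetical names, the toy log reward is invented for the example, and the states are enumerated in base-2 order with the first coordinate as the most significant bit.

```python
import itertools
import math


def true_distribution(log_reward, ndim):
    """Return all states in {0, 1}^ndim and their normalized reward mass."""
    states = list(itertools.product((0, 1), repeat=ndim))
    rewards = [math.exp(log_reward(s)) for s in states]
    total = sum(rewards)  # the partition function Z
    return states, [r / total for r in rewards]


def toy_log_reward(state):
    """Made-up log reward: the number of ones in the state."""
    return float(sum(state))


states, probs = true_distribution(toy_log_reward, 2)
for state, prob in zip(states, probs):
    print(state, round(prob, 4))
```

The probabilities sum to 1, and states with a higher log reward receive exponentially more mass.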
- class gfn.gym.discrete_ebm.EnergyFunction¶
Bases: torch.nn.Module, abc.ABC
Base class for energy functions.
- abstract forward(states)¶
Forward pass of the energy function.
- Parameters:
states (torch.Tensor) – tensor of states of shape (*batch_shape, *state_shape).
- Returns:
Tensor of energies of shape (*batch_shape).
- Return type:
torch.Tensor
- class gfn.gym.discrete_ebm.IsingModel(J)¶
Bases: EnergyFunction
Ising model energy function.
- Parameters:
J (torch.Tensor)
- J¶
Interaction matrix of shape (state_shape, state_shape).
- Type:
torch.Tensor
- forward(states)¶
Forward pass of the Ising model.
- Parameters:
states (torch.Tensor) – tensor of states of shape (*batch_shape, *state_shape).
- Returns:
Tensor of energies of shape (*batch_shape).
- Return type:
torch.Tensor
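As an illustration of what an Ising energy looks like, the sketch below evaluates the quadratic form E(s) = -sum_ij J[i][j] * s_i * s_j for a single configuration. The sign convention, the absence of a 1/2 factor, and the use of spins in {-1, +1} are assumptions taken from the textbook Ising model, not confirmed details of this class. `ising_energy` is a hypothetical name, and nested lists stand in for the tensors.

```python
def ising_energy(spins, J):
    """Energy of one spin configuration under interaction matrix J."""
    n = len(spins)
    if len(J) != n or any(len(row) != n for row in J):
        raise ValueError("J must be a square matrix matching the state length")
    return -sum(
        J[i][j] * spins[i] * spins[j] for i in range(n) for j in range(n)
    )


# Two spins with a symmetric, ferromagnetic coupling.
J = [[0.0, 1.0],
     [1.0, 0.0]]
print(ising_energy([1, 1], J))   # aligned spins: -2.0
print(ising_energy([1, -1], J))  # opposed spins: 2.0
```

Under this convention aligned spins have lower energy when the couplings are positive.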
- linear¶