gfn.gym.discrete_ebm

Classes

DiscreteEBM

Environment for discrete energy-based models.

EnergyFunction

Base class for energy functions.

IsingModel

Ising model energy function.

Module Contents

class gfn.gym.discrete_ebm.DiscreteEBM(ndim, energy=None, alpha=1.0, device='cpu', debug=False)

Bases: gfn.env.DiscreteEnv

Environment for discrete energy-based models.

This environment is based on the paper https://arxiv.org/pdf/2202.01361.pdf.

The states are represented as 1d tensors of length ndim with values in {-1, 0, 1}, where -1 marks an empty position. The initial state s0 is entirely empty, so s0 = [-1, -1, …, -1]. There are 2 * ndim + 1 actions, and each non-exit action replaces a -1 with a 0 or a 1: action i in [0, ndim - 1] replaces s[i] with 0, and action i in [ndim, 2 * ndim - 1] replaces s[i - ndim] with 1. The last action (index 2 * ndim) is the exit action, which is only available for complete states (those with no -1).

Parameters:
  • ndim (int)

  • energy (EnergyFunction | None)

  • alpha (float)

  • device (Literal['cpu', 'cuda'] | torch.device)

  • debug (bool)
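As a rough illustration of the action layout above, the following plain-Python sketch (no torch, states as lists) computes which of the 2 * ndim + 1 actions are legal in a single state. The helper `forward_mask` is hypothetical and not part of the gfn API.

```python
from typing import List


def forward_mask(state: List[int]) -> List[bool]:
    """Hypothetical helper: legal forward actions for one state given as a list.

    Indices 0..ndim-1 write a 0, indices ndim..2*ndim-1 write a 1,
    and the last index (2*ndim) is the exit action.
    """
    # A 0 or a 1 may only be written where the position is still empty (-1).
    empty = [s == -1 for s in state]
    # Exit is legal only once the state is complete (no -1 left).
    can_exit = not any(empty)
    return empty + empty + [can_exit]


print(forward_mask([-1, 1, -1]))
# [True, False, True, True, False, True, False]
```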

ndim

Dimension D of the sampling space {0, 1}^D.

Type:

int

energy

Energy function of the EBM.

Type:

EnergyFunction

alpha

Interaction strength of the EBM.

Type:

float

States: type[gfn.states.DiscreteStates]
property all_states: gfn.states.DiscreteStates

Returns all possible states of the environment.

Return type:

gfn.states.DiscreteStates
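For small ndim, the full state space can be listed by brute force. The sketch below uses plain lists rather than DiscreteStates, and the helper name `enumerate_all_states` is hypothetical; the library may construct or order the states differently.

```python
import itertools
from typing import List


def enumerate_all_states(ndim: int) -> List[List[int]]:
    """Hypothetical helper: every vector in {-1, 0, 1}^ndim, 3**ndim in total."""
    # Lexicographic order with -1 < 0 < 1, i.e. base-3 order with the first
    # entry as the most significant digit.
    return [list(s) for s in itertools.product((-1, 0, 1), repeat=ndim)]


states = enumerate_all_states(2)
print(len(states))            # 9
print(states[0], states[-1])  # [-1, -1] [1, 1]
```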

alpha = 1.0
backward_step(states, actions)

Performs a backward step.

In this environment, states are ndim-dimensional vectors and s0 = [-1, -1, …, -1] is the empty state. Each forward action replaces a -1 with either a 0 or a 1: action i in [0, ndim - 1] replaces s[i] with 0, whereas action i in [ndim, 2 * ndim - 1] replaces s[i - ndim] with 1. A backward action only needs to answer “which index should be set back to -1?”, so the action index is wrapped modulo ndim with fmod: actions i and i + ndim both reset s[i].

Parameters:
  • states – The current states.

  • actions – The backward actions to apply.

Returns:

The previous states.

Return type:

gfn.states.States
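A minimal plain-Python sketch of the index wrapping described above, operating on a single state stored as a list. The helper `backward_step_sketch` is hypothetical. Beyond what this page states, it also assumes that backward action i may only undo the value forward action i would have written (0 for i < ndim, 1 otherwise).

```python
from typing import List


def backward_step_sketch(state: List[int], action: int) -> List[int]:
    """Hypothetical plain-Python analogue of a DiscreteEBM backward step."""
    ndim = len(state)
    if not 0 <= action < 2 * ndim:
        raise ValueError("backward action index out of range")
    # Wrap the index: both i and i + ndim refer to position i.
    # For non-negative ints, Python's % agrees with torch.fmod.
    position = action % ndim
    # Assumption: action i < ndim undoes a 0, action i >= ndim undoes a 1.
    expected_value = action // ndim
    if state[position] != expected_value:
        raise ValueError("action does not match the value stored at this position")
    previous = list(state)
    previous[position] = -1  # reset the position to empty
    return previous


print(backward_step_sketch([0, 1, 1], 4))  # [0, -1, 1]
```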

energy: EnergyFunction = None
get_states_indices(states)

Given that each state is of length ndim with values in {-1, 0, 1}, there are 3**ndim states, which we can label from 0 to 3**ndim - 1.

The easiest way to map each state to a unique integer is to read the state as a number in base 3, where each digit must be in {0, 1, 2}. Each entry is therefore shifted by 1, so that {-1, 0, 1} -> {0, 1, 2}.

Parameters:

states (gfn.states.DiscreteStates) – DiscreteStates object representing the states.

Returns:

The state indices as a tensor of shape (*batch_shape).

Return type:

torch.Tensor
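A plain-Python sketch of the base-3 encoding described above. Treating the first entry as the most significant digit is an assumption, since the digit order is not stated here; `state_index` is a hypothetical name.

```python
from typing import Sequence


def state_index(state: Sequence[int]) -> int:
    """Hypothetical helper: read a {-1, 0, 1} state as a base-3 number.

    Each entry is shifted by 1 so that {-1, 0, 1} -> {0, 1, 2}; the first
    entry is treated as the most significant digit (an assumption).
    """
    index = 0
    for s in state:
        index = index * 3 + (s + 1)
    return index


print(state_index([-1, -1, -1]))  # 0, the initial state s0
print(state_index([0, -1, 1]))    # 1*9 + 0*3 + 2 = 11
print(state_index([1, 1, 1]))     # 26 = 3**3 - 1
```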

get_terminating_states_indices(states)

Given that each terminating state is of length ndim with values in {0, 1}, there are 2**ndim terminating states, which we can label from 0 to 2**ndim - 1.

The easiest way to map each state to a unique integer is to consider the state as a number in base 2.

Parameters:

states (gfn.states.DiscreteStates) – DiscreteStates object representing the states.

Returns:

The indices of the terminating states as a tensor of shape (*batch_shape).

Return type:

torch.Tensor
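The same idea in base 2, as a plain-Python sketch for one complete state. The bit order (first entry most significant) is an assumption, and `terminating_state_index` is a hypothetical name.

```python
from typing import Sequence


def terminating_state_index(state: Sequence[int]) -> int:
    """Hypothetical helper: read a complete {0, 1} state as a base-2 number.

    The first entry is treated as the most significant bit (an assumption).
    """
    index = 0
    for s in state:
        if s not in (0, 1):
            raise ValueError("terminating states contain only 0s and 1s")
        index = index * 2 + s
    return index


print(terminating_state_index([1, 0, 1]))  # 5
```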

is_exit_actions(actions)

Determines if the actions are exit actions.

Parameters:

actions (torch.Tensor) – tensor of actions of shape (*batch_shape, *action_shape).

Returns:

Tensor of booleans of shape (*batch_shape).

Return type:

torch.Tensor
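Since the exit action is the last of the 2 * ndim + 1 actions, its index is 2 * ndim. A plain-Python sketch of the check, assuming one integer action per batch element; `exit_flags` is a hypothetical name, not the library method.

```python
from typing import List


def exit_flags(actions: List[int], ndim: int) -> List[bool]:
    """Hypothetical helper: flag the exit action, the last of 2*ndim + 1 actions."""
    exit_index = 2 * ndim
    return [a == exit_index for a in actions]


print(exit_flags([0, 3, 6, 5], ndim=3))  # [False, False, True, False]
```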

log_partition(condition=None)

Returns the log partition function of the reward, log Z = log sum_x R(x), where x ranges over all terminating states.

Return type:

float
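For small ndim, log Z can be computed by brute force over all 2**ndim terminating states. The sketch below assumes a user-supplied `log_reward` callable on plain tuples (hypothetical, not the library signature) and uses the log-sum-exp trick for numerical stability; it is not necessarily how the library computes the value.

```python
import itertools
import math
from typing import Callable, Tuple


def brute_force_log_partition(log_reward: Callable[[Tuple[int, ...]], float], ndim: int) -> float:
    """Hypothetical helper: log Z = log sum_x exp(log_reward(x)) over all {0, 1}^ndim."""
    log_rewards = [log_reward(x) for x in itertools.product((0, 1), repeat=ndim)]
    # Log-sum-exp trick: factor out the largest term to avoid overflow.
    m = max(log_rewards)
    return m + math.log(sum(math.exp(lr - m) for lr in log_rewards))


# Toy log reward (hypothetical): the number of ones in the state.
# Here Z = (1 + e)**2, so log Z = 2 * log(1 + e).
print(brute_force_log_partition(lambda x: float(sum(x)), ndim=2))
```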

log_reward(final_states)

The log reward is the energy weighted by alpha with its sign flipped, log R(x) = -alpha * E(x), so states with lower energy receive higher reward.

Parameters:

final_states (gfn.states.DiscreteStates) – DiscreteStates object representing the final states.

Returns:

The log reward as a tensor of shape (*batch_shape).

Return type:

torch.Tensor
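A plain-Python sketch of the log reward for one final state, with a toy energy passed in as a callable. Mapping the {0, 1} entries to spins {-1, +1} before evaluating the energy is an assumption (the usual convention for Ising-type energies), and `log_reward_sketch` and `toy_energy` are hypothetical names.

```python
from typing import Callable, List, Sequence


def log_reward_sketch(
    final_state: Sequence[int],
    energy: Callable[[List[int]], float],
    alpha: float = 1.0,
) -> float:
    """Hypothetical helper: log R(x) = -alpha * E(x) for one complete state."""
    # Assumption: {0, 1} entries are mapped to spins {-1, +1} before the energy is evaluated.
    spins = [2 * s - 1 for s in final_state]
    return -alpha * energy(spins)


def toy_energy(spins: List[int]) -> float:
    """Hypothetical energy: lower when more spins point up."""
    return -float(sum(spins))


print(log_reward_sketch([1, 1, 0], toy_energy, alpha=2.0))  # 2.0
```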

make_random_states(batch_shape, conditions=None, device=None, debug=False)

Generates a batch of random states whose underlying tensor has shape (*batch_shape, ndim).

Parameters:
  • batch_shape (Tuple) – The shape of the batch.

  • conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.

  • device (torch.device | None) – The device to use.

  • debug (bool) – If True, emit States with debug guards (not compile-friendly).

Returns:

A DiscreteStates object with random states.

Return type:

gfn.states.DiscreteStates
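A plain-Python sketch of drawing random states, assuming each entry is sampled uniformly from {-1, 0, 1} and using a flat batch size instead of a batch shape. With torch, a tensor of the same shape could be drawn with torch.randint(-1, 2, (*batch_shape, ndim)), whose upper bound is exclusive. The helper `random_states` is hypothetical.

```python
import random
from typing import List, Optional


def random_states(batch_size: int, ndim: int, seed: Optional[int] = None) -> List[List[int]]:
    """Hypothetical helper: draw each entry uniformly from {-1, 0, 1}."""
    rng = random.Random(seed)  # seeded generator for reproducibility
    return [[rng.choice((-1, 0, 1)) for _ in range(ndim)] for _ in range(batch_size)]


batch = random_states(batch_size=4, ndim=3, seed=0)
print(batch)  # four states of length 3 with entries in {-1, 0, 1}
```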

make_states_class()

Returns the DiscreteStates class for the DiscreteEBM environment.

Return type:

type[gfn.states.DiscreteStates]

property n_states: int

Returns the number of states in the environment, 3**ndim.

Return type:

int

property n_terminating_states: int

Returns the number of terminating states in the environment, 2**ndim.

Return type:

int

reward(final_states)

Computes the reward for a batch of final states.

Parameters:

final_states (gfn.states.DiscreteStates) – A batch of final states.

Returns:

A tensor of rewards.

Return type:

torch.Tensor

step(states, actions)

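Performs a forward step. Action i in [0, ndim - 1] replaces s[i] with 0, and action i in [ndim, 2 * ndim - 1] replaces s[i - ndim] with 1.

Parameters:
  • states – The current states.

  • actions – The forward actions to apply.
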
Returns:

The next states as a States object.

Return type:

gfn.states.States
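A plain-Python sketch of a single non-exit forward step on a list-based state. The helper `step_sketch` is hypothetical; how the exit action is handled by the library is not modelled here.

```python
from typing import List


def step_sketch(state: List[int], action: int) -> List[int]:
    """Hypothetical plain-Python analogue of a non-exit DiscreteEBM forward step."""
    ndim = len(state)
    if not 0 <= action < 2 * ndim:
        # The exit action (index 2 * ndim) is not modelled in this sketch.
        raise ValueError("expected a non-exit action in [0, 2 * ndim - 1]")
    position = action % ndim  # i for i < ndim, i - ndim otherwise
    value = action // ndim    # 0 for i < ndim, 1 otherwise
    if state[position] != -1:
        raise ValueError("position is already filled")
    next_state = list(state)
    next_state[position] = value
    return next_state


print(step_sketch([-1, -1, -1], 4))  # [-1, 1, -1]
print(step_sketch([-1, -1, -1], 2))  # [-1, -1, 0]
```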

property terminating_states: gfn.states.DiscreteStates

Returns all terminating states of the environment.

Return type:

gfn.states.DiscreteStates

true_dist(condition=None)

Returns the true probability mass function over terminating states induced by the reward, P(x) = R(x) / Z.

Return type:

torch.Tensor
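For small ndim, the distribution P(x) = R(x) / Z can be computed by brute force. The sketch assumes a hypothetical `log_reward` callable on plain tuples and returns probabilities in base-2 index order (first entry most significant); the library may compute or order the result differently.

```python
import itertools
import math
from typing import Callable, List, Tuple


def brute_force_true_dist(log_reward: Callable[[Tuple[int, ...]], float], ndim: int) -> List[float]:
    """Hypothetical helper: P(x) = R(x) / Z over {0, 1}^ndim, in base-2 index order."""
    # itertools.product yields states with the first entry as the most significant bit.
    states = list(itertools.product((0, 1), repeat=ndim))
    log_rewards = [log_reward(x) for x in states]
    m = max(log_rewards)  # subtract the max before exponentiating for stability
    weights = [math.exp(lr - m) for lr in log_rewards]
    total = sum(weights)
    return [w / total for w in weights]


probs = brute_force_true_dist(lambda x: float(sum(x)), ndim=2)
print(probs)  # increases with the number of ones and sums to 1
```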

class gfn.gym.discrete_ebm.EnergyFunction

Bases: torch.nn.Module, abc.ABC

Base class for energy functions.

abstract forward(states)

Forward pass of the energy function.

Parameters:

states (torch.Tensor) – tensor of states of shape (*batch_shape, *state_shape).

Returns:

Tensor of energies of shape (*batch_shape).

Return type:

torch.Tensor

class gfn.gym.discrete_ebm.IsingModel(J)

Bases: EnergyFunction

Ising model energy function.

Parameters:

J (torch.Tensor)

J

Interaction matrix of shape (state_shape, state_shape).

Type:

torch.Tensor

forward(states)

Forward pass of the Ising model.

Parameters:

states (torch.Tensor) – tensor of states of shape (*batch_shape, *state_shape).

Returns:

Tensor of energies of shape (*batch_shape).

Return type:

torch.Tensor
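A plain-Python sketch of an Ising energy for a single spin vector, assuming the standard form E(x) = -x^T J x with entries of x in {-1, +1}. This convention is common but not spelled out on this page, and `ising_energy` is a hypothetical name.

```python
from typing import Sequence


def ising_energy(spins: Sequence[int], J: Sequence[Sequence[float]]) -> float:
    """Hypothetical helper: E(x) = -x^T J x for a spin vector x with entries in {-1, +1}."""
    n = len(spins)
    return -sum(spins[i] * J[i][j] * spins[j] for i in range(n) for j in range(n))


J = [[0.0, 1.0], [1.0, 0.0]]
print(ising_energy([1, 1], J))   # -2.0: aligned spins have lower energy
print(ising_energy([1, -1], J))  # 2.0
```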

linear