gfn.gym.box

Classes

BoxPolar

Box environment with polar (norm-based) action validation.

Module Contents

class gfn.gym.box.BoxPolar(delta=0.1, R0=0.1, R1=0.5, R2=2.0, epsilon=0.0001, device='cpu', debug=False)

Bases: gfn.env.Env

Box environment with polar (norm-based) action validation.

Corresponds to the environment in Section 4.1 of https://arxiv.org/abs/2301.12594

Actions are 2D vectors whose L2 norm must equal delta (for non-s0 forward steps) or be at most delta (for the initial s0 step). Use with the polar estimators/distributions in box_polar_utils.py.
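To make the two norm constraints concrete, here is a minimal sketch that builds actions in polar form. It is not taken from box_polar_utils.py. It assumes that an action is parametrized by a radius and an angle θ in [0, π/2], so that both components are non-negative. For a non-s0 step the radius is fixed to delta, and for the s0 step the radius can be anything in [0, delta]. The helper name polar_to_action is hypothetical.

```python
import math

import torch


def polar_to_action(theta: torch.Tensor, radius: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: map (radius, angle) pairs to 2D Cartesian actions.

    Assumes angles lie in [0, pi/2], so both action components are non-negative
    (up to floating-point rounding). Returns a tensor of shape (*batch_shape, 2).
    """
    return torch.stack((radius * torch.cos(theta), radius * torch.sin(theta)), dim=-1)


delta = 0.1

# Non-s0 forward step: the radius is fixed to delta, so only the angle varies.
theta = torch.tensor([0.0, math.pi / 4, math.pi / 2])
non_s0_actions = polar_to_action(theta, torch.full_like(theta, delta))

# s0 forward step: the radius may be anything in [0, delta].
radius = torch.tensor([0.0, 0.05, delta])
s0_actions = polar_to_action(theta, radius)

print(torch.linalg.vector_norm(non_s0_actions, dim=-1))  # every entry is (numerically) delta
print(torch.linalg.vector_norm(s0_actions, dim=-1))      # entries equal the chosen radii
```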

See also

BoxCartesian for a simpler per-dimension Cartesian variant.

Parameters:
  • delta (float)

  • R0 (float)

  • R1 (float)

  • R2 (float)

  • epsilon (float)

  • device (Literal['cpu', 'cuda'] | torch.device)

  • debug (bool)

delta

The step size.

R0

The base reward.

R1

The reward for being outside the first box.

R2

The reward for being inside the second box.

epsilon

A small value to avoid numerical issues.

device

The device to use.

Type:

Literal['cpu', 'cuda'] | torch.device

Return type:

torch.device

R0 = 0.1
R1 = 0.5
R2 = 2.0
backward_step(states, actions)

Backward step function for the Box environment.

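Parameters:
  • states (gfn.states.States) – The current states.

  • actions (gfn.actions.Actions) – The actions to undo.

Returns: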

The previous states as a States object.

Return type:

gfn.states.States

delta = 0.1
epsilon = 0.0001
is_action_valid(states, actions, backward=False)

Checks if the actions are valid (polar norm-based semantics).

For polar actions:
  • Forward from s0: norm(action) <= delta.

  • Forward from a non-s0 state: norm(action) == delta (within tolerance).

  • Backward: state - action >= 0 component-wise.

  • Backward to s0: if norm(state) < delta, the action must equal the state.

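Parameters:
  • states (gfn.states.States) – The states the actions are applied to.

  • actions (gfn.actions.Actions) – The actions to check.

  • backward (bool) – If True, check the actions as backward actions.

Returns: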

True if the actions are valid, False otherwise.

Return type:

bool
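The four rules above can be restated on raw tensors as follows. This is a minimal illustration and not the library's implementation. It assumes that states and actions are plain tensors of shape (batch, 2), that s0 is the origin, and that epsilon is the tolerance for the "norm == delta" check. The function name polar_actions_valid is hypothetical.

```python
import torch


def polar_actions_valid(
    states: torch.Tensor,
    actions: torch.Tensor,
    backward: bool = False,
    delta: float = 0.1,
    epsilon: float = 1e-4,
) -> bool:
    """Hypothetical restatement of the polar validity rules on (batch, 2) tensors.

    Assumes s0 is the origin and that epsilon is the tolerance used for the
    "norm == delta" check; both are assumptions made for illustration.
    """
    action_norms = torch.linalg.vector_norm(actions, dim=-1)
    at_s0 = (states == 0).all(dim=-1)

    if not backward:
        # Forward from s0: the first step may be shorter than delta.
        ok_s0 = bool((action_norms[at_s0] <= delta).all())
        # Forward from any other state: the step length must equal delta.
        ok_other = bool(((action_norms[~at_s0] - delta).abs() <= epsilon).all())
        return ok_s0 and ok_other

    # Backward: the previous state must stay non-negative component-wise.
    if not bool((states - actions >= 0).all()):
        return False
    # Backward to s0: states closer than delta to the origin must jump straight back to it.
    near_s0 = torch.linalg.vector_norm(states, dim=-1) < delta
    return bool(torch.allclose(actions[near_s0], states[near_s0], atol=epsilon))


states = torch.tensor([[0.0, 0.0], [0.3, 0.4]])
actions = torch.tensor([[0.03, 0.04], [0.06, 0.08]])  # norms 0.05 (from s0) and 0.1
print(polar_actions_valid(states, actions))
```

Under these assumptions, the second row would be rejected if its action had norm 0.05, because only the step out of s0 may be shorter than delta.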

log_partition(condition=None)

Returns the log partition function, log Z, of the reward (Z is the integral of the reward over the state space).

Return type:

float
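When the reward is piecewise constant, log Z has a closed form. The sketch below assumes the reward of the box task in the cited paper. In that task, R0 applies everywhere. R1 is added where every coordinate satisfies 0.25 < |x_i - 0.5| <= 0.5, and R2 is added where every coordinate satisfies 0.3 < |x_i - 0.5| < 0.4. Per coordinate, these regions have length 0.5 and 0.2, respectively. The function name box_log_partition is hypothetical.

```python
import math


def box_log_partition(R0: float = 0.1, R1: float = 0.5, R2: float = 2.0, dim: int = 2) -> float:
    """Hypothetical closed-form log Z for the assumed box reward on [0, 1]^dim."""
    # Per coordinate, 0.25 < |x - 0.5| <= 0.5 covers [0, 0.25) and (0.75, 1]: length 0.5.
    r1_volume = 0.5**dim
    # Per coordinate, 0.3 < |x - 0.5| < 0.4 covers (0.1, 0.2) and (0.8, 0.9): length 0.2.
    r2_volume = 0.2**dim
    # R0 is earned on the whole unit cube (volume 1); the bonuses only on their regions.
    return math.log(R0 + R1 * r1_volume + R2 * r2_volume)


print(box_log_partition())  # log(0.1 + 0.5 * 0.25 + 2.0 * 0.04) = log(0.305)
```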

make_random_states(batch_shape, conditions=None, device=None, debug=False)

Generates random states whose tensor has shape (*batch_shape, 2).

Parameters:
  • batch_shape (Tuple[int, Ellipsis]) – The shape of the batch.

  • conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.

  • device (torch.device | None) – The device to use.

  • debug (bool) – If True, emit States with debug guards (not compile-friendly).

Returns:

A States object with random states.

Return type:

gfn.states.States
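To illustrate the expected shapes, here is a sketch of the raw tensor that such a call would wrap. It assumes that states are drawn uniformly from the unit square, which the description above does not state. The helper name random_box_tensor is hypothetical, and it returns a plain tensor rather than a gfn.states.States object.

```python
from __future__ import annotations

import torch


def random_box_tensor(batch_shape: tuple[int, ...], device: torch.device | None = None) -> torch.Tensor:
    """Hypothetical helper: uniform samples in [0, 1)^2 with shape (*batch_shape, 2)."""
    return torch.rand(*batch_shape, 2, device=device)


x = random_box_tensor((4, 3))
print(x.shape)  # torch.Size([4, 3, 2])
```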

static norm(x)

Computes the L2 norm of the input tensor along the last dimension.

Parameters:

x (torch.Tensor) – Input tensor of shape (*batch_shape, 2).

Returns:

The L2 norms, as a tensor of shape batch_shape.

Return type:

torch.Tensor
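As described, this computation should match torch.linalg.vector_norm taken over the last dimension. Here is a quick check on a batch of shape (2,):

```python
import torch

x = torch.tensor([[3.0, 4.0], [0.06, 0.08]])  # shape (*batch_shape, 2) with batch_shape = (2,)
norms = torch.linalg.vector_norm(x, dim=-1)  # one L2 norm per 2D vector, shape batch_shape = (2,)
print(norms)  # approximately tensor([5.0000, 0.1000])
```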

reward(final_states)

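Computes the reward of each final state: the base reward R0, plus R1 if the state lies outside the first box and R2 if it lies inside the second box.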

Parameters:

final_states (gfn.states.States) – States object representing the final states.

Returns:

The reward tensor of shape batch_shape.

Return type:

torch.Tensor
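The sketch below assumes the piecewise-constant reward of the box task in the cited paper. In that task, R1 is added when every coordinate is farther than 0.25 from the center 0.5, and R2 is added when every coordinate lies in the band 0.3 < |x_i - 0.5| < 0.4. The function name box_reward is hypothetical, and it operates on a plain tensor rather than a States object.

```python
import torch


def box_reward(x: torch.Tensor, R0: float = 0.1, R1: float = 0.5, R2: float = 2.0) -> torch.Tensor:
    """Hypothetical reward on final states x of shape (*batch_shape, 2).

    Assumes R(x) = R0 + R1 * prod_i 1[0.25 < |x_i - 0.5|]
                      + R2 * prod_i 1[0.3 < |x_i - 0.5| < 0.4].
    """
    ax = (x - 0.5).abs()
    # R1 bonus: every coordinate is farther than 0.25 from the center.
    in_r1 = (ax > 0.25).all(dim=-1)
    # R2 bonus: every coordinate lies in the band (0.3, 0.4) around the center.
    in_r2 = ((ax > 0.3) & (ax < 0.4)).all(dim=-1)
    return R0 + R1 * in_r1.float() + R2 * in_r2.float()


x = torch.tensor([[0.5, 0.5], [0.05, 0.95], [0.15, 0.85]])
print(box_reward(x))  # center -> R0; far corner -> R0 + R1; band region -> R0 + R1 + R2
```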

step(states, actions)

Step function for the Box environment.

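Parameters:
  • states (gfn.states.States) – The current states.

  • actions (gfn.actions.Actions) – The actions to apply.

Returns: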

The next states as a States object.

Return type:

gfn.states.States
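The backward validity rule (state - action >= 0) implies that a backward step subtracts the action, so a forward step presumably adds it. The sketch below makes that assumption and works on plain tensors rather than States objects. The helper names forward_step and backward_step_tensor are hypothetical. It shows that the two transitions undo each other.

```python
import torch


def forward_step(states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
    """Hypothetical forward transition: move each 2D state by its action."""
    return states + actions


def backward_step_tensor(states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
    """Hypothetical backward transition: undo the action."""
    return states - actions


s0 = torch.zeros(2, 2)
first = torch.tensor([[0.03, 0.04], [0.0, 0.1]])   # step out of s0: norm <= delta = 0.1
second = torch.tensor([[0.06, 0.08], [0.1, 0.0]])  # later step: norm == delta
s2 = forward_step(forward_step(s0, first), second)
print(s2)  # positions after two forward steps
print(backward_step_tensor(backward_step_tensor(s2, second), first))  # back at s0, up to rounding
```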