gfn.gym.box
Classes
BoxPolar: Box environment with polar (norm-based) action validation.
Module Contents
- class gfn.gym.box.BoxPolar(delta=0.1, R0=0.1, R1=0.5, R2=2.0, epsilon=0.0001, device='cpu', debug=False)
Bases: gfn.env.Env
Box environment with polar (norm-based) action validation.
Corresponds to the environment in Section 4.1 of https://arxiv.org/abs/2301.12594.
Actions are 2D vectors whose L2 norm must equal delta (for forward steps from non-s0 states) or be at most delta (for the initial step from s0). Use with the polar estimators/distributions in box_polar_utils.py.
See also: BoxCartesian for a simpler per-dimension Cartesian variant.
- Parameters:
delta (float)
R0 (float)
R1 (float)
R2 (float)
epsilon (float)
device (Literal['cpu', 'cuda'] | torch.device)
debug (bool)
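To make the norm constraint on actions concrete, here is a minimal stdlib sketch that is not part of the library: `polar_to_action` is a hypothetical helper that turns an angle and a radius into a 2D Cartesian action. Under the rules stated for this class, the radius may be anything in [0, delta] for the first step from s0 and must equal delta for every later forward step; only the angle is free.

```python
import math
from typing import Tuple


def polar_to_action(angle: float, radius: float) -> Tuple[float, float]:
    """Hypothetical helper: polar (angle in radians, L2 norm) -> 2-D Cartesian action."""
    return (radius * math.cos(angle), radius * math.sin(angle))


delta = 0.1
# First step from s0: any radius in [0, delta] is allowed.
first = polar_to_action(math.pi / 4, 0.5 * delta)
# Every later forward step: the radius is fixed at delta, only the angle varies.
later = polar_to_action(math.pi / 3, delta)

print(math.hypot(*first))  # approximately 0.05
print(math.hypot(*later))  # approximately 0.1
```

Parametrizing by angle makes the fixed-norm rule hold by construction, which is presumably why the class is paired with polar estimators/distributions.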
- delta
The step size.
- R0
The base reward.
- R1
The reward for being outside the first box.
- R2
The reward for being inside the second box.
- epsilon
A small value to avoid numerical issues.
- device
The device to use.
- Type:
Literal["cpu", "cuda"] | torch.device
- Return type:
torch.device
- R0 = 0.1
- R1 = 0.5
- R2 = 2.0
- backward_step(states, actions)
Backward step function for the Box environment.
- Parameters:
states (gfn.states.States) – States object representing the current states.
actions (gfn.actions.Actions) – Actions object representing the actions to be taken.
- Returns:
The previous states as a States object.
- Return type:
gfn.states.States
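This page does not show what backward_step computes. The backward validity rule listed under is_action_valid (state - action >= 0 component-wise) suggests that it subtracts the action from the state. The sketch below assumes exactly that and works on plain tuples rather than the library's batched States/Actions objects; `backward_step_sketch` is a hypothetical name, and s0 is assumed to be the origin.

```python
from typing import Tuple

Vec = Tuple[float, float]


def backward_step_sketch(state: Vec, action: Vec) -> Vec:
    """Hypothetical: undo a forward move by subtracting the action (assumed semantics)."""
    return (state[0] - action[0], state[1] - action[1])


# Undo a forward step of norm 0.1, (0.06, 0.08), taken to reach (0.5, 0.5).
prev = backward_step_sketch((0.5, 0.5), (0.06, 0.08))
# Return to s0 (assumed origin) from a state closer than delta to it:
# the action equals the state itself, so the result is exactly (0, 0).
origin = backward_step_sketch((0.05, 0.02), (0.05, 0.02))
```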
- delta = 0.1
- epsilon = 0.0001
- is_action_valid(states, actions, backward=False)
Checks whether the actions are valid under the polar (norm-based) semantics.
For polar actions:
- Forward from s0: norm(action) <= delta.
- Forward from a non-s0 state: norm(action) == delta (within tolerance).
- Backward: state - action >= 0 component-wise.
- Backward to s0: if norm(state) < delta, the action must equal the state.
- Parameters:
states (gfn.states.States) – The current states.
actions (gfn.actions.Actions) – The actions to be taken.
backward (bool) – Whether the actions are backward actions.
- Returns:
True if the actions are valid, False otherwise.
- Return type:
bool
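The four rules can be read as a per-pair predicate. The sketch below is a hypothetical, unbatched stdlib version, not the library method: it assumes s0 is the origin (which the backward-to-s0 rule implies) and, as an additional assumption, uses epsilon both as the forward-norm tolerance and for the "action equals state" comparison.

```python
import math
from typing import Sequence

S0 = (0.0, 0.0)  # assumption: s0 is the origin, as the backward-to-s0 rule implies


def polar_action_is_valid(
    state: Sequence[float],
    action: Sequence[float],
    delta: float = 0.1,
    epsilon: float = 1e-4,
    backward: bool = False,
) -> bool:
    """Hypothetical single-pair check of the four documented polar rules."""
    action_norm = math.hypot(action[0], action[1])
    if not backward:
        if tuple(state) == S0:
            # Forward from s0: the norm may be anything up to delta.
            return action_norm <= delta
        # Forward from any other state: the norm must equal delta within epsilon.
        return abs(action_norm - delta) <= epsilon
    # Backward: the previous state must stay non-negative in every component.
    if any(s - a < 0 for s, a in zip(state, action)):
        return False
    # Backward to s0: a state closer than delta to s0 must jump straight back,
    # i.e. the action must equal the state (compared within epsilon here).
    if math.hypot(state[0], state[1]) < delta:
        return all(abs(s - a) <= epsilon for s, a in zip(state, action))
    return True
```

For example, `polar_action_is_valid((0.5, 0.5), (0.03, 0.04))` is False because the norm 0.05 is not delta, while the same action from s0 would be accepted.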
- log_partition(condition=None)
Returns the log of the partition function Z, i.e. the log of the reward integrated over the state space.
- Return type:
float
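This page does not spell out the reward formula. Assuming the box reward of Section 4.1 of the cited paper, R(x) = R0 + R1 * prod_d 1[0.25 < |x_d - 0.5|] + R2 * prod_d 1[0.3 < |x_d - 0.5| < 0.4] on the unit square [0, 1]^2, log Z has a closed form because each indicator region is a product of intervals: per coordinate, the first region has length 2 * 0.25 and the second 2 * 0.1. A minimal sketch with the default R0, R1, R2; `box_log_partition` is a hypothetical name.

```python
import math


def box_log_partition(R0: float = 0.1, R1: float = 0.5, R2: float = 2.0) -> float:
    """Hypothetical closed-form log Z for the assumed 2-D box reward."""
    # Area of {x in [0,1]^2 : |x_d - 0.5| > 0.25 for d = 1, 2}.
    outer_area = (2 * 0.25) ** 2
    # Area of {x : 0.3 < |x_d - 0.5| < 0.4 for d = 1, 2}.
    band_area = (2 * 0.1) ** 2
    # R0 applies on the whole unit square, whose area is 1.
    return math.log(R0 + R1 * outer_area + R2 * band_area)


print(box_log_partition())  # log(0.305), about -1.187
```

With the defaults, Z = 0.1 + 0.5 * 0.25 + 2.0 * 0.04 = 0.305.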
- make_random_states(batch_shape, conditions=None, device=None, debug=False)
Generates random states whose tensor has shape (*batch_shape, 2).
- Parameters:
batch_shape (Tuple[int, Ellipsis]) – The shape of the batch.
conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.
device (torch.device | None) – The device to use.
debug (bool) – If True, emit States with debug guards (not compile-friendly).
- Returns:
A States object with random states.
- Return type:
gfn.states.States
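The sampling distribution is not stated on this page. Purely to illustrate the documented shape (*batch_shape, 2), here is a stdlib sketch that assumes points are drawn uniformly from the unit square; `random_box_states` is a hypothetical name and returns nested lists instead of a States object.

```python
import random
from typing import Tuple


def random_box_states(batch_shape: Tuple[int, ...], seed: int = 0) -> list:
    """Hypothetical stand-in returning nested lists of shape (*batch_shape, 2)."""
    rng = random.Random(seed)

    def build(shape: Tuple[int, ...]) -> list:
        if not shape:
            # Innermost level: one 2-D point, assumed uniform in [0, 1)^2.
            return [rng.random(), rng.random()]
        # Peel off the leading batch axis and recurse on the rest.
        return [build(shape[1:]) for _ in range(shape[0])]

    return build(tuple(batch_shape))


states = random_box_states((3, 4))  # 3 x 4 batch of 2-D points
```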
- static norm(x)
Computes the L2 norm of the input tensor along the last dimension.
- Parameters:
x (torch.Tensor) – Input tensor of shape (*batch_shape, 2).
- Returns:
Tensor of L2 norms with shape batch_shape.
- Return type:
torch.Tensor
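The shape contract, (*batch_shape, 2) in and batch_shape out, is the main point of this method. In PyTorch one way to get it is `torch.linalg.vector_norm(x, dim=-1)`, though this page does not say which call the library uses. A dependency-free sketch of the same reduction on nested lists; `norm_last_dim` is a hypothetical name.

```python
import math


def norm_last_dim(x):
    """L2 norm over the innermost axis of a nested list.

    A list of shape (*batch_shape, 2) becomes a list of shape batch_shape;
    a single 2-D point becomes a float.
    """
    if x and not isinstance(x[0], list):
        # Innermost level: a single point, reduce it to its Euclidean length.
        return math.hypot(*x)
    # Otherwise recurse, keeping every batch axis except the last one.
    return [norm_last_dim(row) for row in x]


batch = [[[3.0, 4.0], [6.0, 8.0]], [[0.0, 0.0], [1.0, 0.0]]]  # shape (2, 2, 2)
print(norm_last_dim(batch))  # [[5.0, 10.0], [0.0, 1.0]]
```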
- reward(final_states)
Computes the reward of each final state from its position relative to the two boxes (see R0, R1 and R2).
- Parameters:
final_states (gfn.states.States) – States object representing the final states.
- Returns:
The reward tensor of shape batch_shape.
- Return type:
torch.Tensor
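As a sketch of how R0, R1 and R2 combine, assuming the box reward from Section 4.1 of the cited paper (the exact formula is not given on this page), here is a single-point stdlib version; `box_reward` is a hypothetical name. The bonuses add up, so a point in the inner band also collects the R1 bonus.

```python
from typing import Sequence


def box_reward(x: Sequence[float], R0: float = 0.1, R1: float = 0.5, R2: float = 2.0) -> float:
    """Hypothetical single-point version of the assumed box reward."""
    # Distance of each coordinate from the centre of the unit square.
    ax = [abs(xi - 0.5) for xi in x]
    # R1 bonus: every coordinate lies more than 0.25 from the centre.
    in_outer = all(a > 0.25 for a in ax)
    # R2 bonus: every coordinate lies in the band 0.3 < |x_d - 0.5| < 0.4.
    in_band = all(0.3 < a < 0.4 for a in ax)
    # Booleans act as 0/1 multipliers.
    return R0 + R1 * in_outer + R2 * in_band


print(box_reward((0.5, 0.5)))    # 0.1: centre, base reward only
print(box_reward((0.15, 0.15)))  # about 2.6: inside both regions
```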
- step(states, actions)
Step function for the Box environment.
- Parameters:
states (gfn.states.States) – States object representing the current states.
actions (gfn.actions.Actions) – Actions object representing the actions to be taken.
- Returns:
The next states as a States object.
- Return type:
gfn.states.States
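Mirroring the backward rule (state - action >= 0), a forward step presumably adds the action to the state; that is an assumption, as the method body is not shown. The sketch below rolls out a short trajectory from s0 (assumed to be the origin) on plain tuples: the first move is shorter than delta, every later move has norm exactly delta. `step_sketch` is a hypothetical name, and the fixed diagonal direction is for illustration only.

```python
import math
from typing import List, Tuple

Vec = Tuple[float, float]


def step_sketch(state: Vec, action: Vec) -> Vec:
    """Hypothetical: a forward step adds the action to the state (assumed semantics)."""
    return (state[0] + action[0], state[1] + action[1])


delta = 0.1
state: Vec = (0.0, 0.0)  # s0, assumed to be the origin
trajectory: List[Vec] = [state]
# First move from s0 may be shorter than delta; later moves have norm exactly delta.
radii = [0.05, delta, delta, delta]
angle = math.pi / 4  # fixed diagonal direction, for illustration only
for r in radii:
    state = step_sketch(state, (r * math.cos(angle), r * math.sin(angle)))
    trajectory.append(state)

print(trajectory[-1])  # roughly (0.247, 0.247): 0.35 travelled along the diagonal
```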