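gfn.env
Attributes
Classes
- DiscreteEnv: Base class for discrete environments, where states are defined in a discrete space and actions are represented by an integer in {0, …, n_actions - 1}, the last one being the exit action.
- Env: Base class for all environments.
- GraphEnv: Base class for graph-based environments.
Module Contents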
- class gfn.env.DiscreteEnv(n_actions, s0, state_shape, dummy_action=None, exit_action=None, sf=None, debug=False)
Bases: Env, abc.ABC
Base class for discrete environments, where states are defined in a discrete space and actions are represented by an integer in {0, …, n_actions - 1}, the last one being the exit action.
For a guide on creating your own environments, see the documentation at: guides/creating_environments.
For a complete example, see the HyperGrid environment in src/gfn/gym/hypergrid.py.
- Parameters:
n_actions (int)
s0 (torch.Tensor)
state_shape (Tuple | int)
dummy_action (Optional[torch.Tensor])
exit_action (Optional[torch.Tensor])
sf (Optional[torch.Tensor])
debug (bool)
- n_actions
The number of actions in the environment.
- state_shape
Tuple representing the shape of the states.
- dummy_action
Tensor of shape (1,) representing the dummy action.
- exit_action
Tensor of shape (1,) representing the exit action.
- States
The States class associated with this environment.
- Actions
The Actions class associated with this environment.
- is_discrete
Class variable, whether the environment is discrete.
- _add_logz_diff(validation_info, gflownet, validate_condition)
Compute |learned_logZ - true_logZ| and add to validation_info.
- Parameters:
validation_info (Dict[str, float])
gflownet (gfn.gflownet.GFlowNet)
validate_condition (torch.Tensor | None)
- Return type:
None
- _backward_step(states, actions)
Wrapper for the user-defined backward_step function.
- Parameters:
states (gfn.states.DiscreteStates) – The batch of discrete states.
actions (gfn.actions.Actions) – The batch of actions.
- Returns:
The batch of previous discrete states.
- Return type:
gfn.states.DiscreteStates
- static _jsd(p, q)
Jensen-Shannon divergence between two discrete distributions.
Uses the convention 0 * log(0 / x) = 0 via masking, so zero-probability bins contribute nothing (no clamping, no mass distortion).
- Parameters:
p (torch.Tensor) – First distribution (1-D, sums to ~1).
q (torch.Tensor) – Second distribution (1-D, sums to ~1).
- Returns:
JSD in nats (base-e). Bounded in [0, ln(2)].
- Return type:
float
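The masking convention can be shown with a minimal pure-Python sketch of the Jensen-Shannon divergence. It assumes p and q are plain lists of probabilities rather than torch tensors. It only mirrors the documented behaviour (zero-probability bins skipped, result in nats) and is not the library's implementation.

```python
import math
from typing import Sequence


def kl_masked(p: Sequence[float], m: Sequence[float]) -> float:
    """KL(p || m) in nats, using the convention 0 * log(0 / x) = 0."""
    # Bins where p is zero are skipped instead of being clamped.
    return sum(pi * math.log(pi / mi) for pi, mi in zip(p, m) if pi > 0)


def jsd(p: Sequence[float], q: Sequence[float]) -> float:
    """Jensen-Shannon divergence in nats, bounded in [0, ln(2)]."""
    if len(p) != len(q):
        raise ValueError("p and q must have the same length")
    # The mixture is positive wherever p or q is positive, so the
    # division inside kl_masked never sees a zero denominator.
    m = [0.5 * (pi + qi) for pi, qi in zip(p, q)]
    return 0.5 * kl_masked(p, m) + 0.5 * kl_masked(q, m)


print(jsd([0.5, 0.5, 0.0], [0.5, 0.5, 0.0]))  # identical distributions: 0.0
print(jsd([1.0, 0.0], [0.0, 1.0]))            # disjoint supports: ln(2)
```

Because the mixture m is never zero where p or q has mass, masking on p > 0 is enough to keep every logarithm finite.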
- _step(states, actions)
Wrapper for the user-defined step function.
- Parameters:
states (gfn.states.DiscreteStates) – The batch of discrete states.
actions (gfn.actions.Actions) – The batch of actions.
- Returns:
The batch of next discrete states.
- Return type:
gfn.states.DiscreteStates
- _warn_if_insufficient_samples(n_validation_samples)
Emits a warning if the validation sample count is too low for the state space.
- Parameters:
n_validation_samples (int)
- Return type:
None
- property all_states: gfn.states.DiscreteStates
Abstract property. Optional method to return a batch of all discrete states in the environment.
- Returns:
A batch of all discrete states (batch_shape = (n_states,)).
- Return type:
gfn.states.DiscreteStates
Note: self.get_states_indices(self.all_states) and torch.arange(self.n_states) should be equivalent.
- abstract get_states_indices(states)
Optional method to return the indices of the states in the environment.
- Parameters:
states (gfn.states.DiscreteStates) – The batch of states.
- Returns:
Tensor of shape (*batch_shape) containing the indices of the states.
- Return type:
torch.Tensor
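The invariant between all_states and get_states_indices can be illustrated with a hypothetical pure-Python grid whose states are coordinate tuples. The names all_states, get_states_indices and n_states mirror the documented members. The ToyGrid class, its height and ndim attributes, and the mixed-radix encoding are assumptions made for illustration only. The library's HyperGrid may encode states differently.

```python
from itertools import product
from typing import List, Tuple

State = Tuple[int, ...]


class ToyGrid:
    """Hypothetical grid with side `height` in `ndim` dimensions."""

    def __init__(self, height: int, ndim: int) -> None:
        self.height = height
        self.ndim = ndim

    @property
    def n_states(self) -> int:
        return self.height ** self.ndim

    @property
    def all_states(self) -> List[State]:
        # product() varies the last coordinate fastest. Reversing each tuple
        # makes the first coordinate the least significant digit, so the
        # state at list position i has index i.
        return [tuple(reversed(s)) for s in product(range(self.height), repeat=self.ndim)]

    def get_states_indices(self, states: List[State]) -> List[int]:
        # Mixed-radix encoding: index = sum_i s[i] * height**i.
        return [sum(c * self.height ** i for i, c in enumerate(s)) for s in states]


env = ToyGrid(height=3, ndim=2)
# Analogue of the documented note: indices of all_states == arange(n_states).
assert env.get_states_indices(env.all_states) == list(range(env.n_states))
```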
- get_terminating_state_dist(states)
Computes the empirical distribution over terminating states.
Uses a vectorized scatter_add_ for efficient histogram computation.
- Parameters:
states (gfn.states.DiscreteStates) – A batch of terminating DiscreteStates.
- Returns:
A 1-D CPU tensor of shape (n_terminating_states,) with empirical frequencies summing to 1.
- Raises:
NotImplementedError – If the environment lacks get_terminating_states_indices or n_terminating_states.
ValueError – If states is empty.
- Return type:
torch.Tensor
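The histogram step can be sketched in plain Python. The sketch assumes the terminating-state indices (as returned by get_terminating_states_indices) are already available as a list of ints. The documented method performs the count with a vectorized scatter_add_ on tensors. This version uses an explicit loop for readability.

```python
from typing import List, Sequence


def empirical_dist(indices: Sequence[int], n_terminating_states: int) -> List[float]:
    """Normalized histogram of terminating-state indices."""
    if len(indices) == 0:
        # Mirrors the documented ValueError for an empty batch.
        raise ValueError("states is empty")
    counts = [0] * n_terminating_states
    for idx in indices:
        # Loop equivalent of counts.scatter_add_(0, idx_tensor, ones) in torch.
        counts[idx] += 1
    total = len(indices)
    return [c / total for c in counts]


print(empirical_dist([0, 2, 2, 3], n_terminating_states=4))
# [0.25, 0.0, 0.5, 0.25]
```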
- abstract get_terminating_states_indices(states)
Optional method to return the indices of the terminating states in the environment.
- Parameters:
states (gfn.states.DiscreteStates) – The batch of states.
- Returns:
Tensor of shape (*batch_shape) containing the indices of the terminating states.
- Return type:
torch.Tensor
- is_action_valid(states, actions, backward=False)
Checks whether the actions are valid in the given discrete states.
- Parameters:
states (gfn.states.DiscreteStates) – The batch of discrete states.
actions (gfn.actions.Actions) – The batch of actions.
backward (bool) – If True, checks validity for backward actions.
- Returns:
True if all actions are valid in the given states, False otherwise. When debug is False, returns True without checking to keep hot paths compile-friendly.
- Return type:
bool
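Here is a minimal sketch of the debug-gated check described above. It assumes validity is encoded as per-state boolean masks over integer actions, with separate forward and backward masks. The function name and the list-of-lists mask arguments are hypothetical. The real method takes DiscreteStates and Actions objects.

```python
from typing import Sequence


def is_action_valid_sketch(
    forward_masks: Sequence[Sequence[bool]],
    backward_masks: Sequence[Sequence[bool]],
    actions: Sequence[int],
    backward: bool = False,
    debug: bool = False,
) -> bool:
    """Returns True iff every action is allowed by its state's mask."""
    if not debug:
        # As documented: skip the check entirely when debug is False.
        return True
    masks = backward_masks if backward else forward_masks
    return all(mask[a] for mask, a in zip(masks, actions))


# Two states, three actions each; index 2 (the last) plays the exit action.
fwd = [[True, False, True], [False, True, True]]
bwd = [[True, True, False], [True, False, False]]
print(is_action_valid_sketch(fwd, bwd, [1, 0], debug=False))  # True (unchecked)
print(is_action_valid_sketch(fwd, bwd, [1, 0], debug=True))   # False
print(is_action_valid_sketch(fwd, bwd, [2, 1], debug=True))   # True
```

Returning early when debug is off keeps the fast path free of data-dependent branching. This is the property the documentation refers to as compile-friendly.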
- is_discrete: bool = True
- make_actions_class()
Returns the Actions class for this environment.
- Returns:
A type of a subclass of Actions with environment-specific functionalities.
- Return type:
type[gfn.actions.Actions]
- make_states_class()
Returns the DiscreteStates class for this environment.
- Returns:
A type of a subclass of DiscreteStates with environment-specific functionalities.
- Return type:
type[gfn.states.DiscreteStates]
- n_actions
- property n_states: int
Abstract property. Optional method to return the number of states in the environment.
- Returns:
The number of states.
- Return type:
int
- property n_terminating_states: int
Abstract property. Optional method to return the number of terminating states in the environment.
- Returns:
The number of terminating states.
- Return type:
int
- reset(batch_shape, random=False, sink=False, seed=None, conditions=None)
Instantiates a batch of random, initial, or sink states.
- Parameters:
batch_shape (int | Tuple[int, Ellipsis]) – Shape of the batch (int or tuple).
random (bool) – If True, initialize states randomly.
sink (bool) – If True, initialize states as sink states (\(s_f\)).
seed (Optional[int]) – (Optional) Random seed for reproducibility.
conditions (Optional[torch.Tensor]) – (Optional) Tensor of shape (*batch_shape, condition_dim) containing the conditions.
- Returns:
A batch of random, initial, or sink states.
- Return type:
gfn.states.DiscreteStates
- s0: torch.Tensor
- sf: torch.Tensor
- states_from_batch_shape(batch_shape, random=False, sink=False, conditions=None)
Returns a batch of random, initial, or sink states with a given batch shape.
- Parameters:
batch_shape (int | Tuple[int, Ellipsis]) – Tuple representing the shape of the batch of states.
random (bool) – If True, initialize states randomly.
sink (bool) – If True, initialize states as sink states (\(s_f\)).
conditions (torch.Tensor | None)
- Returns:
A batch of random, initial, or sink states.
- Return type:
gfn.states.DiscreteStates
- states_from_tensor(tensor, conditions=None)
Wraps the supplied tensor in a DiscreteStates instance.
- Parameters:
tensor (torch.Tensor)
conditions (torch.Tensor | None)
- Returns:
An instance of DiscreteStates.
- Return type:
gfn.states.DiscreteStates
- property terminating_states: gfn.states.DiscreteStates
Abstract property. Optional method to return a batch of all terminating states in the environment.
- Returns:
A batch of all terminating states (batch_shape = (n_terminating_states,)).
- Return type:
gfn.states.DiscreteStates
Note: self.get_terminating_states_indices(self.terminating_states) and torch.arange(self.n_terminating_states) should be equivalent.
- validate(gflownet, n_validation_samples=1000, visited_terminating_states=None, validate_condition=None, sampling_chunk_size=5000, check_sample_sufficiency=True)
Evaluates a GFlowNet against this environment’s true distribution.
Always samples fresh trajectories from the current policy to produce an unbiased estimate. Computes the L1 distance and the Jensen-Shannon divergence between the empirical and true distributions. If the GFlowNet has a learned logZ and the environment implements log_partition, also reports their absolute difference.
- Parameters:
gflownet (gfn.gflownet.GFlowNet) – The GFlowNet to evaluate.
n_validation_samples (int) – Number of fresh trajectories to sample.
visited_terminating_states (Optional[gfn.states.DiscreteStates]) – Deprecated. Ignored if passed; a DeprecationWarning is emitted.
validate_condition (torch.Tensor | None) – Optional condition tensor for conditional environments.
sampling_chunk_size (int) – Maximum number of trajectories to sample at once (avoids out-of-memory errors).
check_sample_sufficiency (bool) – If True, emits a one-time warning when n_validation_samples is too small relative to the state space. Set to False to suppress.
- Returns:
(metrics_dict, sampled_terminating_states), where metrics_dict contains "l1_dist", "jsd", and optionally "logZ_diff".
- Raises:
ValueError – If true_dist is unavailable, n_validation_samples is non-positive, or the enumeration APIs are missing.
- Return type:
Tuple[Dict[str, float], gfn.states.DiscreteStates | None]
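The bookkeeping in validate can be sketched as follows. The sketch assumes the empirical and true distributions are equal-length lists of probabilities. The helper names (chunk_sizes, l1_dist, metrics) are hypothetical and are not the library's code. The L1 distance here is the sum of absolute differences. The documentation above does not say whether the library sums or averages over states. The "jsd" entry follows the masked convention documented for _jsd and is left out here.

```python
from typing import Dict, List, Optional, Sequence


def chunk_sizes(n_samples: int, chunk_size: int) -> List[int]:
    """Splits a sample budget into chunks of at most chunk_size."""
    if n_samples <= 0:
        raise ValueError("n_validation_samples must be positive")
    full, rest = divmod(n_samples, chunk_size)
    return [chunk_size] * full + ([rest] if rest else [])


def l1_dist(empirical: Sequence[float], true: Sequence[float]) -> float:
    """Sum of absolute differences between two distributions."""
    return sum(abs(e - t) for e, t in zip(empirical, true))


def metrics(
    empirical: Sequence[float],
    true: Sequence[float],
    learned_logz: Optional[float] = None,
    true_logz: Optional[float] = None,
) -> Dict[str, float]:
    out = {"l1_dist": l1_dist(empirical, true)}
    # "logZ_diff" is reported only when both log-partition values exist.
    if learned_logz is not None and true_logz is not None:
        out["logZ_diff"] = abs(learned_logz - true_logz)
    return out


print(chunk_sizes(12000, 5000))  # [5000, 5000, 2000]
print(metrics([0.25, 0.75], [0.5, 0.5], learned_logz=1.2, true_logz=1.0))
```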
- class gfn.env.Env(s0, state_shape, action_shape, dummy_action, exit_action, sf=None, debug=False)
Bases: abc.ABC
Base class for all environments.
Environments define the state and action spaces, as well as the forward and backward transition functions and the reward function.
- Parameters:
s0 (torch.Tensor | torch_geometric.data.Data)
state_shape (Tuple)
action_shape (Tuple)
dummy_action (torch.Tensor)
exit_action (torch.Tensor)
sf (Optional[torch.Tensor | torch_geometric.data.Data])
debug (bool)
- s0
The initial state (tensor or GeometricData).
- sf
The sink (final) state (tensor or GeometricData).
- state_shape
Tuple representing the shape of the states.
- action_shape
Tuple representing the shape of the actions.
- dummy_action
Tensor representing the dummy action for padding.
- exit_action
Tensor representing the exit action.
- States
The States class associated with this environment.
- Actions
The Actions class associated with this environment.
- is_discrete
Class variable, whether the environment is discrete.
- is_conditional
Class variable, whether the environment is conditional.
- Actions
- States
- _backward_step(states, actions)
Wrapper for the user-defined backward_step function.
This wrapper ensures that the backward_step is called only on valid states and actions, and sets the states to the initial state when the action is not valid. It also ensures that the new states are a distinct object from the old states.
- Parameters:
states (gfn.states.States) – A batch of states.
actions (gfn.actions.Actions) – A batch of actions.
- Returns:
A batch of previous states.
- Return type:
gfn.states.States
- _step(states, actions)
Wrapper for the user-defined step function.
This wrapper ensures that the step is called only on valid states and actions, and sets the states to the sink state when the action is exit. It also ensures that the new states are a distinct object from the old states.
- Parameters:
states (gfn.states.States) – A batch of states.
actions (gfn.actions.Actions) – A batch of actions.
- Returns:
A batch of next states.
- Return type:
gfn.states.States
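The wrapper contract can be sketched with plain Python lists. In this hypothetical version, states are integers on a line and action 0 means "move right by one". A designated exit action sends the state to a sink sentinel. user_step stands in for the subclass's step and only sees non-exit transitions. SINK, EXIT and both function names are assumptions for illustration. The real wrapper operates on batched States and Actions objects and also validates actions. This sketch only shows the routing and the guarantee that a distinct object is returned.

```python
from typing import Callable, List

SINK = -1  # hypothetical sink-state sentinel, playing the role of sf
EXIT = 1   # hypothetical exit action id


def user_step(states: List[int], actions: List[int]) -> List[int]:
    """Stand-in for a subclass's step(): action 0 moves right by one."""
    return [s + 1 for s in states]


def wrapped_step(
    states: List[int],
    actions: List[int],
    step_fn: Callable[[List[int], List[int]], List[int]],
) -> List[int]:
    # Copy first, so the caller's states are never mutated.
    new_states = list(states)
    # Route only the non-exit transitions to the user-defined step.
    keep = [i for i, a in enumerate(actions) if a != EXIT]
    stepped = step_fn([states[i] for i in keep], [actions[i] for i in keep])
    for i, s in zip(keep, stepped):
        new_states[i] = s
    # Exit transitions go straight to the sink state.
    for i, a in enumerate(actions):
        if a == EXIT:
            new_states[i] = SINK
    return new_states


old = [0, 3, 5]
new = wrapped_step(old, [0, 1, 0], user_step)
print(new)  # [1, -1, 6]
print(old)  # [0, 3, 5] (unchanged)
```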
- action_shape
- actions_from_batch_shape(batch_shape)
Returns a batch of dummy actions with the supplied batch shape.
- Parameters:
batch_shape (Tuple) – Tuple representing the shape of the batch of actions.
- Returns:
A batch of dummy actions.
- Return type:
gfn.actions.Actions
- actions_from_tensor(tensor)
Wraps the supplied tensor in an Actions instance.
- Parameters:
tensor (torch.Tensor) – Tensor of shape (*action_shape) representing the actions.
- Returns:
An Actions instance.
- Return type:
gfn.actions.Actions
- abstract backward_step(states, actions)
Backward transition function of the environment.
This method takes a batch of states and actions and returns a batch of previous states. It does not need to check whether the actions are valid or the states are sink states, because the _backward_step method wraps it and checks for validity.
- Parameters:
states (gfn.states.States) – A batch of states.
actions (gfn.actions.Actions) – A batch of actions.
- Returns:
A batch of previous states.
- Return type:
gfn.states.States
- debug = False
- property device: torch.device
The device on which the environment’s elements are stored.
- Returns:
The device of the initial state.
- Return type:
torch.device
- dummy_action
- exit_action
- abstract is_action_valid(states, actions, backward=False)
Checks whether the actions are valid in the given states.
- Parameters:
states (gfn.states.States) – A batch of states.
actions (gfn.actions.Actions) – A batch of actions.
backward (bool) – If True, checks validity for backward actions.
- Returns:
True if all actions are valid in the given states, False otherwise.
- Return type:
bool
- is_conditional: bool = False
- is_discrete: bool = False
- abstract log_partition(condition=None)
Optional method to return the logarithm of the partition function.
- Parameters:
condition (torch.Tensor | None) – Optional tensor of shape (condition_dim,) containing the condition.
- Returns:
The log partition function.
- Return type:
float
- log_reward(states)
Returns the environment’s log rewards for a batch of states.
Either this method or reward must be implemented by the environment.
- Parameters:
states (gfn.states.States) – A batch of states with a batch_shape.
- Returns:
Tensor of shape (*batch_shape) containing the log rewards.
- Return type:
torch.Tensor
- make_actions_class()
Returns the Actions class for this environment.
Defines a custom Actions class that inherits from Actions and implements the assumed methods. Override make_actions_class to add more environment-specific Actions functionality.
- Returns:
A type of a subclass of Actions with environment-specific functionalities.
- Return type:
type[gfn.actions.Actions]
- abstract make_random_states(batch_shape, conditions=None, device=None)
Optional method to return a batch of random states.
- Parameters:
batch_shape (Tuple) – Tuple representing the shape of the batch of states.
conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.
device (torch.device | None) – The device to create the states on.
- Returns:
A batch of random states.
- Return type:
gfn.states.States
- make_states_class()
Returns the States class for this environment.
Defines a custom States class that inherits from States and implements the assumed methods. Override make_states_class to add more environment-specific States functionality.
- Returns:
A type of a subclass of States with environment-specific functionalities.
- Return type:
type[gfn.states.States]
- reset(batch_shape, random=False, sink=False, seed=None, conditions=None)
Instantiates a batch of random, initial, or sink states.
- Parameters:
batch_shape (int | Tuple[int, Ellipsis]) – Shape of the batch (int, tuple, or list).
random (bool) – If True, initialize states randomly.
sink (bool) – If True, initialize states as sink states (\(s_f\)).
seed (Optional[int]) – (Optional) Random seed for reproducibility.
conditions (Optional[torch.Tensor]) – (Optional) Tensor of shape (*batch_shape, condition_dim) containing the conditions.
- Returns:
A batch of random, initial, or sink states.
- Return type:
gfn.states.States
- abstract reward(states)
Returns the environment’s rewards for a batch of states.
Either this method or log_reward must be implemented by the environment.
- Parameters:
states (gfn.states.States) – A batch of states with a batch_shape.
- Returns:
Tensor of shape (*batch_shape) containing the rewards.
- Return type:
torch.Tensor
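An environment may implement either reward or log_reward. One natural way to relate them is log_reward(x) = log(reward(x)), shown below with a hypothetical reward over integer states. This is an assumption about how a missing method could be derived, not a statement of the library's default behaviour. The third function shows why working in log space directly can be preferable when rewards underflow.

```python
import math
from typing import List


def reward(states: List[int]) -> List[float]:
    """Hypothetical strictly positive reward: R(x) = exp(-x / 2)."""
    return [math.exp(-x / 2) for x in states]


def log_reward(states: List[int]) -> List[float]:
    """Log rewards derived from reward(); requires R(x) > 0."""
    return [math.log(r) for r in reward(states)]


def log_reward_direct(states: List[int]) -> List[float]:
    """Same quantity computed in log space, with no exp underflow."""
    return [-x / 2 for x in states]


print(log_reward([0, 2]))
print(log_reward_direct([0, 2]))
# For x = 2000, reward() underflows to 0.0 (so log_reward would fail),
# while log_reward_direct() still returns -1000.0.
print(reward([2000]), log_reward_direct([2000]))
```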
- s0
- abstract sample_conditions(batch_shape)
Sample conditions for the environment. Required for conditional environments.
- Parameters:
batch_shape (int | Tuple[int, Ellipsis]) – The shape of the batch of conditions to sample.
- Returns:
A tensor of shape (*batch_shape, condition_dim) containing the conditions.
- Return type:
torch.Tensor
- sf
- state_shape
- states_from_batch_shape(batch_shape, random=False, sink=False, conditions=None)
Returns a batch of random, initial, or sink states with a given batch shape.
- Parameters:
batch_shape (int | Tuple[int, Ellipsis]) – Tuple representing the shape of the batch of states.
random (bool) – If True, initialize states randomly (requires implementation).
sink (bool) – If True, initialize states as sink states (\(s_f\)).
conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.
- Returns:
A batch of random, initial, or sink states.
- Return type:
gfn.states.States
- states_from_tensor(tensor, conditions=None)
Wraps the supplied tensor in a States instance.
- Parameters:
tensor (torch.Tensor)
conditions (torch.Tensor | None)
- Returns:
A States instance.
- Return type:
gfn.states.States
- abstract step(states, actions)
Forward transition function of the environment.
This method takes a batch of states and actions and returns a batch of next states. It does not need to check whether the actions are valid or the states are sink states, because the _step method wraps it and checks for validity.
- Parameters:
states (gfn.states.States) – A batch of states.
actions (gfn.actions.Actions) – A batch of actions.
- Returns:
A batch of next states.
- Return type:
gfn.states.States
- abstract true_dist(condition=None)
Optional method to return the true distribution.
- Parameters:
condition (torch.Tensor | None) – Optional tensor of shape (condition_dim,) containing the condition.
- Returns:
The true distribution as a 1-dimensional tensor.
- Return type:
torch.Tensor
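In the GFlowNet setting, the target distribution over terminating states is proportional to the reward. That gives p(x) = R(x) / Z with Z = Σ_x R(x), and log_partition corresponds to log Z. The sketch below assumes the log rewards of all terminating states are available as a plain list. This is an assumption, since the documented methods take only an optional condition. It computes both quantities with a numerically stable log-sum-exp.

```python
import math
from typing import List, Sequence


def log_partition(log_rewards: Sequence[float]) -> float:
    """log Z = log sum_x exp(log R(x)), via a stable log-sum-exp."""
    m = max(log_rewards)
    # Subtracting the max keeps every exp() argument <= 0 (no overflow).
    return m + math.log(sum(math.exp(lr - m) for lr in log_rewards))


def true_dist(log_rewards: Sequence[float]) -> List[float]:
    """p(x) = R(x) / Z as a 1-D list summing to 1."""
    log_z = log_partition(log_rewards)
    return [math.exp(lr - log_z) for lr in log_rewards]


log_r = [math.log(1.0), math.log(3.0)]
print(true_dist(log_r))       # approximately [0.25, 0.75]
print(log_partition(log_r))   # approximately 1.3863, i.e. log(4)
```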
- class gfn.env.GraphEnv(s0, sf, num_node_classes, num_edge_classes, is_directed, debug=False)
Bases: Env
Base class for graph-based environments.
Graph environments represent states as graphs (torch_geometric Data objects) and actions as graph modifications.
- Parameters:
s0 (torch_geometric.data.Data)
sf (torch_geometric.data.Data)
num_node_classes (int)
num_edge_classes (int)
is_directed (bool)
debug (bool)
- s0
GeometricData representing the initial graph state.
- sf
GeometricData representing the sink (final) graph state.
- num_node_classes
Number of node classes.
- num_edge_classes
Number of edge classes.
- is_directed
Whether the graph is directed.
- States
The States class associated with this environment.
- Actions
The Actions class associated with this environment.
- Actions
- States
- abstract backward_step(states, actions)
Backward transition function of the graph environment.
This method takes a batch of graph states and actions and returns a batch of previous graph states.
- Parameters:
states (gfn.states.GraphStates) – A batch of graph states.
actions (gfn.actions.GraphActions) – A batch of graph actions.
- Returns:
A batch of previous graph states.
- Return type:
gfn.states.GraphStates
- debug = False
- property device: torch.device
The device on which the graph states are stored.
- Returns:
The device of the initial graph state’s node features.
- Return type:
torch.device
- dummy_action
- exit_action
- is_directed
- make_actions_class()
Returns the GraphActions class for this environment.
- Returns:
A type of a subclass of GraphActions with environment-specific functionalities.
- Return type:
type[gfn.actions.GraphActions]
- abstract make_random_states(batch_shape, conditions=None, device=None)
Optional method to return a batch of random graph states.
- Parameters:
batch_shape (int | Tuple) – Shape of the batch (int or tuple).
conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.
device (torch.device | None) – The device to create the graph states on.
- Returns:
A batch of random graph states.
- Return type:
gfn.states.GraphStates
- make_states_class()
Returns the GraphStates class for this environment.
- Returns:
A type of a subclass of GraphStates with environment-specific functionalities.
- Return type:
type[gfn.states.GraphStates]
- num_edge_classes
- num_node_classes
- reset(batch_shape, random=False, sink=False, seed=None, conditions=None)
Instantiates a batch of random, initial, or sink graph states.
- Parameters:
batch_shape (int | Tuple[int, Ellipsis]) – Shape of the batch (int or tuple).
random (bool) – If True, initialize states randomly.
sink (bool) – If True, initialize states as sink states (\(s_f\)).
seed (Optional[int]) – (Optional) Random seed for reproducibility.
conditions (Optional[torch.Tensor]) – (Optional) Tensor of shape (*batch_shape, condition_dim) containing the conditions.
- Returns:
A batch of random, initial, or sink graph states.
- Return type:
gfn.states.GraphStates
- s0: torch_geometric.data.Data
- sf: torch_geometric.data.Data
- abstract step(states, actions)
Forward transition function of the graph environment.
This method takes a batch of graph states and actions and returns a batch of next graph states.
- Parameters:
states (gfn.states.GraphStates) – A batch of graph states.
actions (gfn.actions.GraphActions) – A batch of graph actions.
- Returns:
A batch of next graph states.
- Return type:
gfn.states.GraphStates
- gfn.env.NonValidActionsError