gfn.gym.bayesian_structure
Classes
BayesianStructure: Environment for incrementally building a directed acyclic graph (DAG) for Bayesian structure learning.
Module Contents
- class gfn.gym.bayesian_structure.BayesianStructure(n_nodes, state_evaluator, device='cpu', debug=False)
Bases: gfn.gym.graph_building.GraphBuilding
Environment for incrementally building a directed acyclic graph (DAG) for Bayesian structure learning (Deleu et al., 2022).
The environment allows the following actions:
  - Adding edges between existing nodes with features
  - Terminating construction (EXIT)
- Parameters:
n_nodes (int) – Number of nodes in the graph.
state_evaluator (Callable[[gfn.states.GraphStates], torch.Tensor]) – Callable that computes rewards for final states. If None, uses the default GCNConvEvaluator.
device (Literal['cpu', 'cuda'] | torch.device) – Device to run computations on (‘cpu’ or ‘cuda’).
debug (bool)
- backward_step(states, actions)
Performs a backward step in the environment.
- Parameters:
states (gfn.states.GraphStates) – The current states.
actions (gfn.actions.GraphActions) – The actions to undo.
- Returns:
The previous states.
- Return type:
- is_action_valid(states, actions, backward=False)
Check if actions are valid for the given states.
- Parameters:
states (gfn.states.GraphStates) – Current graph states.
actions (gfn.actions.GraphActions) – Actions to validate.
backward (bool) – Whether this is a backward step.
- Returns:
True if all actions are valid, False otherwise.
- Return type:
bool
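For an environment that builds DAGs by adding edges, validity usually reduces to a few graph checks. The sketch below illustrates them on a plain adjacency matrix. It is a library-independent illustration, not the implementation behind is_action_valid; the helper names has_path and is_edge_action_valid are hypothetical, and the rules shown (no self-loops, no duplicate edges, no cycles) are assumptions about what a DAG-building environment would enforce.

```python
from typing import List

Adjacency = List[List[int]]  # adj[u][v] == 1 means there is an edge u -> v


def has_path(adj: Adjacency, src: int, dst: int) -> bool:
    """Return True if dst is reachable from src (iterative depth-first search)."""
    stack, seen = [src], set()
    while stack:
        node = stack.pop()
        if node == dst:
            return True
        if node in seen:
            continue
        seen.add(node)
        stack.extend(v for v, edge in enumerate(adj[node]) if edge)
    return False


def is_edge_action_valid(
    adj: Adjacency, u: int, v: int, backward: bool = False
) -> bool:
    """Hypothetical rule for adding (or, when backward, removing) edge u -> v."""
    if backward:
        # An edge addition can only be undone if the edge is present.
        return adj[u][v] == 1
    if u == v or adj[u][v] == 1:
        return False  # no self-loops and no duplicate edges
    # Adding u -> v closes a cycle exactly when v already reaches u.
    return not has_path(adj, v, u)


adj = [[0, 1, 0], [0, 0, 1], [0, 0, 0]]  # chain 0 -> 1 -> 2
print(is_edge_action_valid(adj, 0, 2))  # True: a shortcut edge keeps the graph acyclic
print(is_edge_action_valid(adj, 2, 0))  # False: would create the cycle 0 -> 1 -> 2 -> 0
```

Note that the documented method returns a single bool for a whole batch, so a batched version of this check would combine the per-graph results (for example with all()).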
- log_reward(final_states)
The environment’s log reward given a state. Either this or reward must be implemented.
- Parameters:
final_states (gfn.states.GraphStates) – A batch of final states.
- Returns:
Tensor of shape batch_shape containing the log rewards.
- Return type:
torch.Tensor
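The relationship between a log reward and a reward can be sketched without the library. In the example below, toy_log_score is a hypothetical stand-in for a state evaluator (it simply penalizes edges), and plain Python lists stand in for tensors and graph states; none of these names come from gfn.

```python
import math
from typing import List

Adjacency = List[List[int]]


def toy_log_score(adj: Adjacency) -> float:
    """Hypothetical evaluator: each edge costs 0.5 in log space."""
    n_edges = sum(sum(row) for row in adj)
    return -0.5 * n_edges


def log_reward(final_states: List[Adjacency]) -> List[float]:
    """One log reward per final state in the batch."""
    return [toy_log_score(adj) for adj in final_states]


def reward(final_states: List[Adjacency]) -> List[float]:
    """Reward recovered from the log reward by exponentiation."""
    return [math.exp(value) for value in log_reward(final_states)]


batch = [
    [[0, 0], [0, 0]],  # empty graph on two nodes
    [[0, 1], [0, 0]],  # single edge 0 -> 1
]
print(log_reward(batch))
print(reward(batch))
```

Working in log space keeps scores numerically stable when rewards span many orders of magnitude, which is common for structure-learning scores.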
- make_actions_class()
Returns the GraphActions class for this environment.
- Returns:
A subclass of GraphActions with environment-specific functionality.
- Return type:
type[gfn.actions.GraphActions]
- make_random_states_tensor(batch_shape, conditions=None, device=None)
Makes a batch of random DAG states with a fixed number of nodes.
- Parameters:
batch_shape (int | Tuple) – Shape of the batch dimensions.
conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.
device (torch.device | None) – The device to create the graph states on.
- Returns:
A PyG Batch object containing random DAG states.
- Return type:
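One common way to draw a random DAG over a fixed node set is to shuffle the nodes into a random order and only allow edges that point forward in that order. The sketch below shows that technique on plain adjacency matrices. It is an assumption-laden illustration, not a description of how make_random_states_tensor samples; random_dag, is_acyclic, and edge_prob are hypothetical names.

```python
import random
from typing import List, Optional

Adjacency = List[List[int]]


def random_dag(
    n_nodes: int, edge_prob: float = 0.5, seed: Optional[int] = None
) -> Adjacency:
    """Sample a DAG by fixing a random node order and adding forward edges."""
    rng = random.Random(seed)
    order = list(range(n_nodes))
    rng.shuffle(order)  # this order is a valid topological order by construction
    adj = [[0] * n_nodes for _ in range(n_nodes)]
    for i in range(n_nodes):
        for j in range(i + 1, n_nodes):
            if rng.random() < edge_prob:
                adj[order[i]][order[j]] = 1  # edges only go forward in the order
    return adj


def is_acyclic(adj: Adjacency) -> bool:
    """Kahn's algorithm: a graph is acyclic iff every node can be peeled off."""
    n = len(adj)
    indegree = [sum(adj[u][v] for u in range(n)) for v in range(n)]
    ready = [v for v in range(n) if indegree[v] == 0]
    visited = 0
    while ready:
        u = ready.pop()
        visited += 1
        for v in range(n):
            if adj[u][v]:
                indegree[v] -= 1
                if indegree[v] == 0:
                    ready.append(v)
    return visited == n


# A "batch" of three random DAGs, each with the same fixed number of nodes.
batch = [random_dag(4, seed=s) for s in range(3)]
print(all(is_acyclic(graph) for graph in batch))  # True
```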
- make_states_class()
Creates a GraphStates class for this environment.
- Return type:
type[gfn.states.GraphStates]
- n_actions
- n_nodes
- abstract reward(final_states)
The environment’s reward given a state.
- Parameters:
final_states (gfn.states.GraphStates) – A batch of final states.
- Returns:
A tensor of shape (batch_size,) containing the rewards.
- Return type:
torch.Tensor
- step(states, actions)
Performs a step in the environment.
- Parameters:
states (gfn.states.GraphStates) – The current states.
actions (gfn.actions.GraphActions) – The actions to take.
- Returns:
The next states.
- Return type:
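The relationship between step and backward_step can be illustrated with a toy model in which a state is an adjacency matrix and an action is an edge (u, v). This is only a sketch under those assumptions: the real methods operate on batched gfn.states.GraphStates and gfn.actions.GraphActions objects, and the functions below are hypothetical stand-ins that share the method names for readability.

```python
import copy
from typing import List, Tuple

Adjacency = List[List[int]]
Edge = Tuple[int, int]


def step(adj: Adjacency, edge: Edge) -> Adjacency:
    """Forward step: return a new state with edge u -> v added."""
    u, v = edge
    next_adj = copy.deepcopy(adj)  # leave the input state untouched
    next_adj[u][v] = 1
    return next_adj


def backward_step(adj: Adjacency, edge: Edge) -> Adjacency:
    """Backward step: return a new state with edge u -> v removed."""
    u, v = edge
    prev_adj = copy.deepcopy(adj)
    prev_adj[u][v] = 0
    return prev_adj


s0 = [[0] * 3 for _ in range(3)]  # initial state: three nodes, no edges
s1 = step(s0, (0, 1))
s2 = step(s1, (1, 2))

# Undoing the last action recovers the previous state.
print(backward_step(s2, (1, 2)) == s1)  # True
```

Returning new states rather than mutating the inputs keeps earlier states in a trajectory usable after later steps.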