gfn.gym.graph_building
Classes
- GraphBuilding: Environment for incrementally building graphs.
- GraphBuildingOnEdges: Environment for building graphs edge by edge with a discrete action space.
Module Contents
- class gfn.gym.graph_building.GraphBuilding(num_node_classes, num_edge_classes, state_evaluator, is_directed=True, max_nodes=None, device='cpu', s0=None, sf=None, debug=False)
Bases: gfn.env.GraphEnv
Environment for incrementally building graphs.
This environment allows constructing graphs by:
  - Adding a node of a given class.
  - Adding an edge of a given class between two existing nodes.
  - Terminating construction (EXIT).
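The three construction moves listed above can be illustrated with a small, library-independent sketch. It does not use torchgfn: `ToyGraph`, `AddNode`, `AddEdge`, `Exit`, `step`, and `backward_step` are hypothetical names chosen for illustration, and the graph is held in plain Python lists rather than a `torch_geometric.data.Data` object. It shows a forward step applying an action and a backward step undoing it, mirroring the roles that `step` and `backward_step` play on this class.

```python
from dataclasses import dataclass, field
from typing import List, Tuple, Union


@dataclass(frozen=True)
class AddNode:
    node_class: int


@dataclass(frozen=True)
class AddEdge:
    src: int
    dst: int
    edge_class: int


@dataclass(frozen=True)
class Exit:
    pass


Action = Union[AddNode, AddEdge, Exit]


@dataclass
class ToyGraph:
    # Hypothetical stand-in for a graph state: one class label per node and
    # edges stored as (src, dst, edge_class) triples.
    node_classes: List[int] = field(default_factory=list)
    edges: List[Tuple[int, int, int]] = field(default_factory=list)
    done: bool = False


def step(state: ToyGraph, action: Action) -> ToyGraph:
    """Apply a forward action and return a new state (the input is not mutated)."""
    if state.done:
        raise ValueError("cannot act on a terminated graph")
    nodes, edges = list(state.node_classes), list(state.edges)
    if isinstance(action, AddNode):
        nodes.append(action.node_class)
    elif isinstance(action, AddEdge):
        if not (0 <= action.src < len(nodes) and 0 <= action.dst < len(nodes)):
            raise ValueError("edge endpoints must be existing nodes")
        edges.append((action.src, action.dst, action.edge_class))
    else:  # Exit marks the state as final.
        return ToyGraph(nodes, edges, done=True)
    return ToyGraph(nodes, edges)


def backward_step(state: ToyGraph, action: Action) -> ToyGraph:
    """Undo `action`, assuming it was the action that produced `state`."""
    nodes, edges = list(state.node_classes), list(state.edges)
    if isinstance(action, AddNode):
        last = len(nodes) - 1
        # In this sketch only the newest node can be removed, and only if it has no edges.
        if last < 0 or nodes[last] != action.node_class:
            raise ValueError("no matching node to remove")
        if any(last in (src, dst) for src, dst, _ in edges):
            raise ValueError("cannot remove a node that still has edges")
        nodes.pop()
    elif isinstance(action, AddEdge):
        edges.remove((action.src, action.dst, action.edge_class))  # ValueError if absent
    else:  # Undoing Exit only clears the terminal flag.
        return ToyGraph(nodes, edges, done=False)
    return ToyGraph(nodes, edges)


# Usage: build a two-node graph with one edge, then undo the edge.
s = ToyGraph()
for a in (AddNode(0), AddNode(1), AddEdge(0, 1, edge_class=2)):
    s = step(s, a)
previous = backward_step(s, AddEdge(0, 1, edge_class=2))
```

Returning fresh states instead of mutating the input keeps forward and backward steps easy to compare; the newest-node-only rule for undoing `AddNode` is a simplification of this sketch.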
- Parameters:
num_node_classes (int)
num_edge_classes (int)
state_evaluator (Callable[[gfn.states.GraphStates], torch.Tensor])
is_directed (bool)
max_nodes (int | None)
device (Literal['cpu', 'cuda'] | torch.device)
s0 (torch_geometric.data.Data | None)
sf (torch_geometric.data.Data | None)
debug (bool)
- num_node_classes
The number of node classes.
- num_edge_classes
The number of edge classes.
- state_evaluator
A callable that computes rewards for final states.
- is_directed
Whether the graph is directed.
- backward_step(states, actions)
Performs a backward step in the environment.
- Parameters:
states (gfn.states.GraphStates) – The current states.
actions (gfn.actions.GraphActions) – The actions to undo.
- Returns:
The previous states.
- Return type:
- is_action_valid(states, actions, backward=False)
Checks whether the actions are valid for the given states.
- Parameters:
states (gfn.states.GraphStates) – Current graph states.
actions (gfn.actions.GraphActions) – Actions to validate.
backward (bool) – Whether this is a backward step.
- Returns:
True if all actions are valid, False otherwise.
- Return type:
bool
- make_actions_class()
Returns the GraphActions class for this environment.
- Returns:
A subclass of GraphActions (the class itself, not an instance) with environment-specific functionality.
- Return type:
type[gfn.actions.GraphActions]
- make_random_states(batch_shape, conditions=None, device=None, debug=False)
Generates random states.
- Parameters:
batch_shape (Tuple) – The shape of the batch.
conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.
device (torch.device | None) – The device to use.
debug (bool) – If True, emit States with debug guards (not compile-friendly).
- Returns:
A GraphStates object with random states.
- Return type:
- make_states_class()
Creates a GraphStates class for this environment.
- Return type:
type[gfn.states.GraphStates]
- max_nodes = None
- reward(final_states)
Computes the environment’s reward for a batch of final states.
- Parameters:
final_states (gfn.states.GraphStates) – A batch of final states.
- Returns:
A tensor of shape (batch_size,) containing the rewards.
- Return type:
torch.Tensor
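The `state_evaluator` passed to the constructor maps a batch of final states to one reward per state. A minimal sketch of such a function follows, assuming graphs are given as plain lists of `(src, dst)` edges and the output is a Python list rather than a `torch.Tensor`; `edge_count_reward` and its exponential form are hypothetical. GFlowNet rewards must be non-negative, and keeping them strictly positive keeps log-rewards finite.

```python
import math
from typing import List, Sequence, Tuple

Edge = Tuple[int, int]


def edge_count_reward(final_graphs: Sequence[List[Edge]], temperature: float = 1.0) -> List[float]:
    """Hypothetical evaluator: reward = exp(num_edges / temperature).

    Returns one strictly positive float per graph, i.e. a flat output of
    length batch_size, matching the (batch_size,) shape documented for `reward`.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    return [math.exp(len(edges) / temperature) for edges in final_graphs]


# Usage: three final graphs with 0, 1, and 2 edges.
rewards = edge_count_reward([[], [(0, 1)], [(0, 1), (1, 2)]])
```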
- step(states, actions)
Performs a step in the environment.
- Parameters:
states (gfn.states.GraphStates) – The current states.
actions (gfn.actions.GraphActions) – The actions to take.
- Returns:
The next states.
- Return type:
- class gfn.gym.graph_building.GraphBuildingOnEdges(n_nodes, state_evaluator, directed, device, debug=False)
Bases: GraphBuilding
Environment for building graphs edge by edge with a discrete action space.
The environment supports both directed and undirected graphs.
In each state, the policy can:
  1. Add an edge between existing nodes.
  2. Use the exit action to terminate graph building.
The action space is discrete, with size:
  - For directed graphs: n_nodes^2 - n_nodes + 1 (all possible directed edges, plus exit).
  - For undirected graphs: (n_nodes^2 - n_nodes)/2 + 1 (the strict upper triangle of the adjacency matrix, plus exit).
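The two size formulas can be checked with a short sketch that counts candidate edges and enumerates them. The function names are hypothetical, and the index-to-edge ordering used here (via `itertools.permutations` and `itertools.combinations`, with exit as the final index) is an assumption made for illustration, not torchgfn's documented layout.

```python
from itertools import combinations, permutations
from typing import List, Tuple


def action_space_size(n_nodes: int, directed: bool) -> int:
    """One discrete action per possible edge, plus one exit action."""
    n_edges = n_nodes * n_nodes - n_nodes  # ordered pairs (i, j) with i != j
    if not directed:
        n_edges //= 2  # unordered pairs: the strict upper triangle
    return n_edges + 1


def enumerate_edges(n_nodes: int, directed: bool) -> List[Tuple[int, int]]:
    """A hypothetical index -> edge ordering; torchgfn's own ordering may differ."""
    nodes = range(n_nodes)
    if directed:
        return list(permutations(nodes, 2))  # all (i, j) with i != j
    return list(combinations(nodes, 2))  # all (i, j) with i < j


# Usage: with 4 nodes, 13 directed actions (12 edges + exit) or 7 undirected (6 + exit).
for directed in (True, False):
    edges = enumerate_edges(4, directed)
    # In this sketch the exit action takes the index right after the last edge.
    assert action_space_size(4, directed) == len(edges) + 1
```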
- Parameters:
n_nodes (int)
state_evaluator (callable)
directed (bool)
device (Literal['cpu', 'cuda'] | torch.device)
debug (bool)
- n_nodes
The number of nodes in the graph.
- Type:
int
- n_possible_edges
The number of possible edges.
- Type:
int
- is_action_valid(states, actions, backward=False)
Checks if the actions are valid.
- Parameters:
states (gfn.states.GraphStates) – The current states.
actions (gfn.actions.GraphActions) – The actions to validate.
backward (bool) – Whether the actions are backward actions.
- Returns:
True if the actions are valid, False otherwise.
- Return type:
bool
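For an edge-only environment, validity can be pictured with the following hypothetical sketch. It assumes that a forward action is valid only when the chosen edge is absent from its state and a backward action only when it is present, with `(i, j)` and `(j, i)` treated as the same edge for undirected graphs. These rules, the plain-list representation, and the name `edge_actions_valid` are assumptions for illustration, not taken from torchgfn's implementation.

```python
from typing import Iterable, Sequence, Tuple

Edge = Tuple[int, int]


def _key(edge: Edge, directed: bool) -> Edge:
    # Undirected edges are compared with the smaller endpoint first.
    i, j = edge
    return edge if directed else (min(i, j), max(i, j))


def edge_actions_valid(
    batch_edges: Sequence[Iterable[Edge]],
    actions: Sequence[Edge],
    n_nodes: int,
    directed: bool,
    backward: bool = False,
) -> bool:
    """True only if every (state, edge action) pair in the batch is valid.

    Assumed rules: a forward action must add an edge that is absent, and a
    backward action must remove an edge that is present. Exit actions are
    not modeled here.
    """
    if len(batch_edges) != len(actions):
        raise ValueError("need exactly one action per state")
    for edges, (i, j) in zip(batch_edges, actions):
        if i == j or not (0 <= i < n_nodes and 0 <= j < n_nodes):
            return False  # self-loops and out-of-range nodes are rejected
        present = {_key(e, directed) for e in edges}
        if (_key((i, j), directed) in present) != backward:
            return False  # forward needs absent, backward needs present
    return True


# Usage: the single state already contains edge (0, 1).
print(edge_actions_valid([[(0, 1)]], [(1, 0)], 3, directed=True))   # True: (1, 0) is a new directed edge
print(edge_actions_valid([[(0, 1)]], [(1, 0)], 3, directed=False))  # False: same undirected edge
```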
- make_random_states(batch_shape, conditions=None, device=None, debug=False)
Makes a batch of random graph states with a fixed number of nodes.
- Parameters:
batch_shape (Tuple) – Shape of the batch dimensions.
conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.
device (torch.device | None) – The device to use.
debug (bool) – If True, emit States with debug guards (not compile-friendly).
- Returns:
A GraphStates object containing random graph states.
- Return type:
- make_states_class()
Creates a GraphStates class for this environment.
- Return type:
type[gfn.states.GraphStates]
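The fixed-node random states described for `make_random_states` on this class can be pictured with a plain-Python sketch that keeps the node count at `n_nodes` and varies only the edges. The function `random_edge_sets`, the independent-edge (Erdős–Rényi-style) sampling with probability `p`, and the flattening of `batch_shape` into a single list are all assumptions for illustration; torchgfn's actual sampling distribution is not documented here.

```python
import math
import random
from itertools import combinations, permutations
from typing import List, Optional, Tuple

Edge = Tuple[int, int]


def random_edge_sets(
    batch_shape: Tuple[int, ...],
    n_nodes: int,
    directed: bool,
    p: float = 0.5,
    seed: Optional[int] = None,
) -> List[List[Edge]]:
    """Sample one edge list per batch element; every graph has exactly n_nodes nodes.

    Each candidate edge is kept independently with probability p (an assumed
    distribution). The batch is returned as a flat list of length prod(batch_shape).
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError("p must lie in [0, 1]")
    rng = random.Random(seed)  # local generator, so results are reproducible per seed
    nodes = range(n_nodes)
    candidates = list(permutations(nodes, 2)) if directed else list(combinations(nodes, 2))
    batch_size = math.prod(batch_shape)  # e.g. (2, 3) -> 6 graphs
    return [[e for e in candidates if rng.random() < p] for _ in range(batch_size)]


# Usage: a (2, 3) batch of undirected 4-node graphs.
graphs = random_edge_sets((2, 3), n_nodes=4, directed=False, seed=0)
```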
- n_nodes