gfn.gym.graph_building

Classes

GraphBuilding

Environment for incrementally building graphs.

GraphBuildingOnEdges

Environment for building graphs edge by edge with discrete action space.

Module Contents

class gfn.gym.graph_building.GraphBuilding(num_node_classes, num_edge_classes, state_evaluator, is_directed=True, max_nodes=None, device='cpu', s0=None, sf=None, debug=False)

Bases: gfn.env.GraphEnv

Environment for incrementally building graphs.

This environment allows constructing graphs by:
  • Adding nodes of a given class.

  • Adding edges of a given class between existing nodes.

  • Terminating construction (EXIT).
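The three operations can be pictured with a small, self-contained toy model. This is only a sketch, not the library's implementation. The `ToyGraph` dataclass, the `ActionType` enum and the `apply` helper are hypothetical names introduced for illustration. Edges are stored as directed (src, dst) pairs, and the real environment works on batched `GraphStates` and `GraphActions` objects rather than on one graph at a time.

```python
from dataclasses import dataclass, field
from enum import Enum


class ActionType(Enum):
    # Hypothetical action kinds mirroring the three operations listed above.
    ADD_NODE = 0
    ADD_EDGE = 1
    EXIT = 2


@dataclass
class ToyGraph:
    node_classes: list = field(default_factory=list)  # class index of each node
    edges: dict = field(default_factory=dict)  # (src, dst) -> edge class
    done: bool = False  # set once the EXIT action has been taken


def apply(graph, action, node_class=None, edge=None, edge_class=None):
    """Apply one action to a single toy graph, in place (hypothetical helper)."""
    if graph.done:
        raise ValueError("graph construction has already terminated")
    if action is ActionType.ADD_NODE:
        graph.node_classes.append(node_class)
    elif action is ActionType.ADD_EDGE:
        src, dst = edge
        n = len(graph.node_classes)
        # Edges may only connect nodes that already exist.
        if not (0 <= src < n and 0 <= dst < n):
            raise ValueError("edge endpoints must be existing nodes")
        graph.edges[(src, dst)] = edge_class
    else:  # ActionType.EXIT
        graph.done = True
    return graph


g = ToyGraph()
apply(g, ActionType.ADD_NODE, node_class=0)
apply(g, ActionType.ADD_NODE, node_class=1)
apply(g, ActionType.ADD_EDGE, edge=(0, 1), edge_class=2)
apply(g, ActionType.EXIT)
print(g)
```

Raising on an edge whose endpoints do not exist mirrors the rule that edges are only added "between existing nodes".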

Parameters:
  • num_node_classes (int)

  • num_edge_classes (int)

  • state_evaluator (Callable[[gfn.states.GraphStates], torch.Tensor])

  • is_directed (bool)

  • max_nodes (int | None)

  • device (Literal['cpu', 'cuda'] | torch.device)

  • s0 (torch_geometric.data.Data | None)

  • sf (torch_geometric.data.Data | None)

  • debug (bool)

num_node_classes

The number of node classes.

num_edge_classes

The number of edge classes.

state_evaluator

A callable that computes rewards for final states.

is_directed

Whether the graph is directed.

backward_step(states, actions)

Performs a backward step in the environment.

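Parameters:
  • states (gfn.states.GraphStates) – The batch of states to step back from.

  • actions (gfn.actions.GraphActions) – The actions to undo.

Returns: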

The previous states.

Return type:

gfn.states.GraphStates
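Conceptually, a backward step undoes the forward action that produced a state. The sketch below illustrates this for edge additions on a single graph stored as a set of (src, dst) pairs. `forward_add_edge` and `backward_add_edge` are hypothetical helpers, not part of the library, which applies the same idea to whole batches of `GraphStates`.

```python
def forward_add_edge(edges, edge):
    """Hypothetical forward step: return a new edge set with `edge` added."""
    if edge in edges:
        raise ValueError(f"edge {edge} already present")
    return edges | {edge}


def backward_add_edge(edges, edge):
    """Hypothetical backward step: return the parent edge set without `edge`."""
    if edge not in edges:
        raise ValueError(f"edge {edge} is not present, so it cannot be removed")
    return edges - {edge}


s0 = frozenset()
s1 = forward_add_edge(s0, (0, 1))
s2 = forward_add_edge(s1, (1, 2))
# Stepping backward with the last action recovers the previous state.
parent = backward_add_edge(s2, (1, 2))
print(sorted(parent))
```

Using immutable `frozenset`s keeps every intermediate state intact, so the round trip can be checked by simple equality.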

is_action_valid(states, actions, backward=False)

Check if actions are valid for the given states.

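Parameters:
  • states (gfn.states.GraphStates) – The current states.

  • actions (gfn.actions.GraphActions) – The actions to check.

  • backward (bool) – If True, check the actions as backward actions.

Returns: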

True if all actions are valid, False otherwise.

Return type:

bool
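The method returns a single `bool` for the whole batch rather than one flag per state. A minimal sketch of that all-or-nothing check follows. It assumes each state is a set of directed (src, dst) edges over a known number of nodes and each action adds (or, with `backward=True`, removes) one edge. `edge_action_is_valid` and `batch_is_valid` are hypothetical names, and the library's real checks also cover node additions, edge classes and exit actions.

```python
def edge_action_is_valid(num_nodes, edges, edge, backward=False):
    """Hypothetical per-state check for a single edge action."""
    src, dst = edge
    # Both endpoints must refer to nodes that exist in the graph.
    if not (0 <= src < num_nodes and 0 <= dst < num_nodes):
        return False
    if backward:
        # A backward action can only remove an edge that is present.
        return edge in edges
    # A forward action cannot add an edge that already exists.
    return edge not in edges


def batch_is_valid(num_nodes, states, actions, backward=False):
    # One invalid action makes the whole batch invalid.
    return all(
        edge_action_is_valid(num_nodes, edges, edge, backward)
        for edges, edge in zip(states, actions)
    )


states = [{(0, 1)}, set()]
print(batch_is_valid(3, states, [(1, 2), (0, 2)]))  # True: both edges are new
print(batch_is_valid(3, states, [(0, 1), (0, 2)], backward=True))  # False: (0, 2) absent
```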

make_actions_class()

Returns the GraphActions class for this environment.

Returns:

A subclass of GraphActions (the class itself, not an instance) with environment-specific functionality.

Return type:

type[gfn.actions.GraphActions]
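Returning a class rather than an instance lets the environment bake its own configuration into the type that later builds action objects. The sketch below shows that factory pattern in plain Python. `BaseActions`, `ToyEnv` and the `num_edge_classes` class attribute are hypothetical stand-ins and do not reproduce the real `GraphActions` interface.

```python
class BaseActions:
    """Hypothetical generic base class, meant to be subclassed per environment."""

    num_edge_classes = None  # filled in by the environment-specific subclass

    def __init__(self, edge_class):
        if not 0 <= edge_class < self.num_edge_classes:
            raise ValueError("edge class out of range for this environment")
        self.edge_class = edge_class


class ToyEnv:
    def __init__(self, num_edge_classes):
        self.num_edge_classes = num_edge_classes
        self.Actions = self.make_actions_class()

    def make_actions_class(self):
        env = self

        # The subclass captures this environment's settings as class attributes.
        class EnvActions(BaseActions):
            num_edge_classes = env.num_edge_classes

        return EnvActions


env = ToyEnv(num_edge_classes=3)
a = env.Actions(edge_class=2)  # accepted because 2 < 3
print(issubclass(env.Actions, BaseActions), a.edge_class)
```

Each environment gets its own subclass, so two environments with different settings never share class-level state.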

make_random_states(batch_shape, conditions=None, device=None, debug=False)

Generates random states.

Parameters:
  • batch_shape (Tuple) – The shape of the batch.

  • conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.

  • device (torch.device | None) – The device to use.

  • debug (bool) – If True, emit States with debug guards (not compile-friendly).

Returns:

A GraphStates object with random states.

Return type:

gfn.states.GraphStates

make_states_class()

Creates a GraphStates class for this environment.

Return type:

type[gfn.states.GraphStates]

max_nodes = None
reward(final_states)

Computes the environment’s reward for a batch of final states using state_evaluator.

Parameters:

final_states (gfn.states.GraphStates) – A batch of final states.

Returns:

A tensor of shape (batch_size,) containing the rewards.

Return type:

torch.Tensor
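Because the reward comes from the user-supplied `state_evaluator`, the evaluator's main contract is to map a batch of final states to one reward per graph, i.e. shape `(batch_size,)`. A minimal sketch of such an evaluator follows, assuming graphs are represented as plain edge lists instead of `GraphStates` and using a Python list in place of the 1-D tensor. `edge_count_reward` and its `eps` floor are hypothetical choices for illustration.

```python
def edge_count_reward(batch_of_edge_lists, eps=1e-3):
    """Hypothetical evaluator: one reward per graph, its edge count floored at `eps`.

    A plain list stands in for the tensor of shape (batch_size,) that the
    real environment returns. The floor keeps every reward strictly
    positive, which GFlowNet training objectives generally assume.
    """
    return [max(float(len(edges)), eps) for edges in batch_of_edge_lists]


final_graphs = [[(0, 1), (1, 2)], [], [(2, 0)]]
rewards = edge_count_reward(final_graphs)
print(rewards)  # [2.0, 0.001, 1.0]
```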

step(states, actions)

Performs a step in the environment.

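Parameters:
  • states (gfn.states.GraphStates) – The current states.

  • actions (gfn.actions.GraphActions) – The actions to apply.

Returns: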

The next states.

Return type:

gfn.states.GraphStates

class gfn.gym.graph_building.GraphBuildingOnEdges(n_nodes, state_evaluator, directed, device, debug=False)

Bases: GraphBuilding

Environment for building graphs edge by edge with discrete action space.

The environment supports both directed and undirected graphs.

In each state, the policy can:
  1. Add an edge between existing nodes.

  2. Use the exit action to terminate graph building.

The action space is discrete, with size:
  • Directed graphs: n_nodes^2 - n_nodes + 1 (all possible directed edges, plus exit).

  • Undirected graphs: (n_nodes^2 - n_nodes)/2 + 1 (the upper triangle of the adjacency matrix, plus exit).
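The two sizes differ only in whether (i, j) and (j, i) count as separate edges. Self-loops are excluded in both cases, which is why n_nodes is subtracted from n_nodes^2. The sketch below computes the sizes and enumerates one possible index-to-edge ordering. The row-major ordering and placing the exit action at the last index are assumptions for illustration, not necessarily the ordering the environment uses.

```python
def num_actions(n_nodes, directed):
    """Size of the discrete action space: every possible edge plus one exit action."""
    if directed:
        n_edges = n_nodes * n_nodes - n_nodes  # ordered pairs (i, j) with i != j
    else:
        n_edges = (n_nodes * n_nodes - n_nodes) // 2  # pairs with i < j (upper triangle)
    return n_edges + 1


def edge_list(n_nodes, directed):
    """One possible edge enumeration (assumed); index len(result) would be exit."""
    if directed:
        return [(i, j) for i in range(n_nodes) for j in range(n_nodes) if i != j]
    return [(i, j) for i in range(n_nodes) for j in range(i + 1, n_nodes)]


for directed in (True, False):
    edges = edge_list(4, directed)
    print(directed, num_actions(4, directed), edges[:3])
```

With 4 nodes this gives 13 actions in the directed case and 7 in the undirected case.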

Parameters:
  • n_nodes (int)

  • state_evaluator (callable)

  • directed (bool)

  • device (Literal['cpu', 'cuda'] | torch.device)

  • debug (bool)

n_nodes

The number of nodes in the graph.

Type:

int

n_possible_edges

The number of possible edges.

Type:

int

is_action_valid(states, actions, backward=False)

Checks if the actions are valid.

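Parameters:
  • states (gfn.states.GraphStates) – The current states.

  • actions (gfn.actions.GraphActions) – The actions to check.

  • backward (bool) – If True, check the actions as backward actions.

Returns: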

True if the actions are valid, False otherwise.

Return type:

bool
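With a flat discrete action space, a validity check reduces to decoding each index into an edge (or the exit action) and testing the current adjacency. The sketch below assumes an undirected graph, upper-triangle edges numbered row by row with the exit action as the last index, and a 0/1 adjacency matrix per state. These layout details, the treatment of exit in backward mode, and the function name are hypothetical.

```python
from itertools import combinations


def actions_are_valid(adjacencies, actions, n_nodes, backward=False):
    """Hypothetical check: True only if every action in the batch is valid.

    Assumed layout: actions 0..E-1 index the upper-triangle edges in
    row-major order, and index E is the exit action.
    """
    edges = list(combinations(range(n_nodes), 2))  # (i, j) with i < j
    exit_index = len(edges)
    for adj, index in zip(adjacencies, actions):
        if not 0 <= index <= exit_index:
            return False  # outside the discrete action space
        if index == exit_index:
            if backward:
                return False  # assumption: exit is not undone by a backward step here
            continue
        i, j = edges[index]
        present = adj[i][j] == 1
        # Forward actions must add a missing edge; backward ones must remove a present one.
        if present != backward:
            return False
    return True


empty = [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
one_edge = [[0, 1, 0], [1, 0, 0], [0, 0, 0]]  # edge (0, 1), action index 0
print(actions_are_valid([empty, one_edge], [0, 2], n_nodes=3))  # True
print(actions_are_valid([one_edge], [0], n_nodes=3))  # False: edge already present
print(actions_are_valid([one_edge], [0], n_nodes=3, backward=True))  # True
```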

make_random_states(batch_shape, conditions=None, device=None, debug=False)

Makes a batch of random graph states with a fixed number of nodes (n_nodes).

Parameters:
  • batch_shape (Tuple) – Shape of the batch dimensions.

  • conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.

  • device (torch.device | None) – The device to use.

  • debug (bool) – If True, emit States with debug guards (not compile-friendly).

Returns:

A GraphStates object containing random graph states.

Return type:

gfn.states.GraphStates
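A batch of random graphs on a fixed node set can be drawn by sampling each possible edge independently. This sketch uses an independent Bernoulli(p) model per edge of an undirected graph and flattens `batch_shape` into a plain list. The edge probability `p`, the seeding and the helper name are assumptions, and the library may sample differently.

```python
import math
import random
from itertools import combinations


def random_edge_sets(batch_shape, n_nodes, p=0.5, seed=None):
    """Hypothetical sampler: prod(batch_shape) undirected graphs on n_nodes nodes.

    Each of the n_nodes * (n_nodes - 1) / 2 possible edges is included
    independently with probability p. Returns a flat list of edge lists.
    """
    rng = random.Random(seed)
    batch_size = math.prod(batch_shape)  # an empty shape () gives a single graph
    candidates = list(combinations(range(n_nodes), 2))
    return [
        [edge for edge in candidates if rng.random() < p]
        for _ in range(batch_size)
    ]


graphs = random_edge_sets((2, 3), n_nodes=4, seed=0)
print(len(graphs), graphs[0])
```

Fixing the seed makes the batch reproducible, which is convenient when comparing samplers or debugging.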

make_states_class()

Creates a GraphStates class for this environment.

Return type:

type[gfn.states.GraphStates]
