gfn.gym.bayesian_structure

Classes

BayesianStructure

Environment for incrementally building a directed acyclic graph (DAG) for Bayesian structure learning.

Module Contents

class gfn.gym.bayesian_structure.BayesianStructure(n_nodes, state_evaluator, device='cpu', debug=False)

Bases: gfn.gym.graph_building.GraphBuilding

Environment for incrementally building a directed acyclic graph (DAG) for Bayesian structure learning (Deleu et al., 2022).

The environment allows the following actions:
  • Adding edges between existing nodes with features

  • Terminating construction (EXIT)
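The two action types can be pictured as one discrete action space. The sketch below assumes a flat encoding in the spirit of Deleu et al. (2022): each of the n_nodes² ordered node pairs is an "add edge" action, and one extra index stands for EXIT. The helper names (`num_actions`, `encode_action`, `decode_action`, `exit_index`) are hypothetical. The `n_actions` attribute is not documented on this page, so the count computed here is an assumption and not necessarily how the library's GraphActions encodes actions.

```python
from typing import Optional, Tuple


def num_actions(n_nodes: int) -> int:
    # Hypothetical encoding: one "add edge" action per ordered pair (source, target),
    # plus one EXIT action. Self-loops get an index too and are expected to be masked.
    return n_nodes * n_nodes + 1


def encode_action(source: int, target: int, n_nodes: int) -> int:
    """Map the directed edge source -> target to a flat index in [0, n_nodes**2)."""
    return source * n_nodes + target


def exit_index(n_nodes: int) -> int:
    """The last index is reserved for the EXIT action."""
    return n_nodes * n_nodes


def decode_action(index: int, n_nodes: int) -> Optional[Tuple[int, int]]:
    """Inverse of encode_action; returns None for the EXIT index."""
    if index == exit_index(n_nodes):
        return None
    source, target = divmod(index, n_nodes)
    return source, target


print(num_actions(3))                            # 10
print(decode_action(encode_action(1, 2, 3), 3))  # (1, 2)
print(decode_action(exit_index(3), 3))           # None
```

A flat encoding like this makes it easy to mask invalid edges with a single boolean vector of length n_nodes² + 1.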

Parameters:
  • n_nodes (int) – Number of nodes in the graph.

  • state_evaluator (Callable[[gfn.states.GraphStates], torch.Tensor]) – Callable that computes rewards for final states. If None, uses the default GCNConvEvaluator.

  • device (Literal['cpu', 'cuda'] | torch.device) – Device to run computations on (‘cpu’ or ‘cuda’).

  • debug (bool)

backward_step(states, actions)

Performs a backward step in the environment.

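Parameters:
  • states (gfn.states.GraphStates) – A batch of current states.

  • actions (gfn.actions.GraphActions) – The actions to undo.

Returns: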

The previous states.

Return type:

gfn.states.GraphStates
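Since forward actions add edges, a backward step can be understood as removing the edge that the corresponding forward step added. The sketch below assumes a hypothetical state representation (a DAG as a frozenset of directed edges) and a hypothetical helper `backward_step_sketch`; it illustrates the idea for a single state and is not the library's batched GraphStates implementation.

```python
from typing import FrozenSet, Tuple

Edge = Tuple[int, int]
# Hypothetical state representation: a DAG as an immutable set of directed edges.
DAGEdges = FrozenSet[Edge]


def backward_step_sketch(edges: DAGEdges, edge: Edge) -> DAGEdges:
    """Return the previous state by removing the edge that a forward step added."""
    if edge not in edges:
        raise ValueError(f"cannot remove missing edge {edge}")
    return edges - {edge}


current = frozenset({(0, 1), (1, 2)})
previous = backward_step_sketch(current, (1, 2))
print(sorted(previous))  # [(0, 1)]
```

Using an immutable set keeps the original state untouched, which matches the documented behavior of returning the previous states rather than modifying the input.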

is_action_valid(states, actions, backward=False)

Check if actions are valid for the given states.

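Parameters:
  • states (gfn.states.GraphStates) – A batch of current states.

  • actions (gfn.actions.GraphActions) – The actions to check.

  • backward (bool) – If True, check the actions as backward actions. Defaults to False.

Returns: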

True if all actions are valid, False otherwise.

Return type:

bool
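For a DAG environment, the core of a validity check is acyclicity. Adding an edge u -> v is valid only if it is not a self-loop, not already present, and v cannot already reach u (otherwise the new edge would close a cycle). A backward action is valid only if the edge to remove exists. The sketch below is a single-state illustration with hypothetical helpers (`is_edge_action_valid`, `all_actions_valid`); the documented method instead works on batched GraphStates and returns one bool for the whole batch, which `all_actions_valid` mimics with `all(...)`.

```python
from collections import defaultdict
from typing import Iterable, List, Set, Tuple

Edge = Tuple[int, int]


def _reachable(edges: Iterable[Edge], start: int, goal: int) -> bool:
    """Depth-first search: is there a directed path start -> ... -> goal?"""
    children = defaultdict(list)
    for u, v in edges:
        children[u].append(v)
    stack, seen = [start], set()
    while stack:
        node = stack.pop()
        if node == goal:
            return True
        if node in seen:
            continue
        seen.add(node)
        stack.extend(children[node])
    return False


def is_edge_action_valid(edges: Set[Edge], edge: Edge, backward: bool = False) -> bool:
    u, v = edge
    if backward:
        # A backward action may only remove an edge that is present.
        return edge in edges
    if u == v or edge in edges:
        return False  # no self-loops, no duplicate edges
    # Adding u -> v creates a cycle iff v already reaches u.
    return not _reachable(edges, v, u)


def all_actions_valid(batch: List[Tuple[Set[Edge], Edge]], backward: bool = False) -> bool:
    """Hypothetical batched check: True only if every (state, action) pair is valid."""
    return all(is_edge_action_valid(e, a, backward) for e, a in batch)
```

For example, with edges {(0, 1), (1, 2)}, adding (0, 2) is valid, while adding (2, 0) would create the cycle 0 -> 1 -> 2 -> 0 and is rejected.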

log_reward(final_states)

The log of the environment’s reward for a batch of final states.

Parameters:

final_states (gfn.states.GraphStates) – A batch of final states.

Returns:

Tensor of shape batch_shape containing the log rewards.

Return type:

torch.Tensor
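In Deleu et al. (2022) the reward of a graph G is proportional to its posterior, P(G) P(D | G). With decomposable scores such as BDe or BGe, the log reward is a sum of local terms, one per node given its parent set: log R(G) = sum over j of LocalScore(X_j | Pa_G(X_j)). The sketch below illustrates that structure with a hypothetical toy local score; it is not the scoring used by this environment, whose rewards come from `state_evaluator`. Whether the library treats the evaluator's output as a reward or a log reward is not stated on this page.

```python
import math
from typing import Callable, FrozenSet, Set, Tuple

Edge = Tuple[int, int]
# A local score maps (node, parent set) to a log-score contribution.
LocalScore = Callable[[int, FrozenSet[int]], float]


def parents(edges: Set[Edge], node: int) -> FrozenSet[int]:
    """Parent set of `node`: every u with an edge u -> node."""
    return frozenset(u for u, v in edges if v == node)


def modular_log_reward(edges: Set[Edge], n_nodes: int, local_score: LocalScore) -> float:
    """log R(G) as a sum of per-node local scores given each node's parents."""
    return sum(local_score(j, parents(edges, j)) for j in range(n_nodes))


def toy_local_score(node: int, pa: FrozenSet[int]) -> float:
    # Hypothetical stand-in for a real score such as BDe or BGe:
    # each parent multiplies the reward by 0.5.
    return len(pa) * math.log(0.5)


edges = {(0, 1), (0, 2), (1, 2)}
log_r = modular_log_reward(edges, 3, toy_local_score)
reward = math.exp(log_r)  # reward = exp(log_reward)
```

Working in log space avoids underflow when the posterior scores of large graphs are very small.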

make_actions_class()

Returns the GraphActions class for this environment.

Returns:

A type of a subclass of GraphActions with environment-specific functionalities.

Return type:

type[gfn.actions.GraphActions]

make_random_states_tensor(batch_shape, conditions=None, device=None)

Makes a batch of random DAG states with a fixed number of nodes.

Parameters:
  • batch_shape (int | Tuple) – Shape of the batch dimensions.

  • conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.

  • device (torch.device | None) – The device to create the graph states on.

Returns:

A PyG Batch object containing random DAG states.

Return type:

gfn.states.GraphStates
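One standard way to sample random DAGs with a fixed number of nodes is to draw a random topological order and then include each "forward" pair (earlier node to later node) with some probability. Every edge points forward in the order, so no cycle can form. The sketch below uses this technique with hypothetical helpers (`random_dag`, `random_dag_batch`, `is_acyclic`) and an assumed `edge_prob` parameter; it is not necessarily how the library samples states, and it ignores `conditions` and `device`.

```python
import random
from typing import List, Optional, Set, Tuple

Edge = Tuple[int, int]


def random_dag(n_nodes: int, edge_prob: float = 0.3,
               rng: Optional[random.Random] = None) -> Set[Edge]:
    """Sample a DAG by keeping each forward pair of a random node order."""
    rng = rng or random.Random()
    order = list(range(n_nodes))
    rng.shuffle(order)  # random topological order
    edges = set()
    for i in range(n_nodes):
        for j in range(i + 1, n_nodes):
            if rng.random() < edge_prob:
                edges.add((order[i], order[j]))  # always earlier -> later
    return edges


def random_dag_batch(batch_size: int, n_nodes: int, edge_prob: float = 0.3,
                     seed: Optional[int] = None) -> List[Set[Edge]]:
    """Hypothetical batched version; a seed makes the batch reproducible."""
    rng = random.Random(seed)
    return [random_dag(n_nodes, edge_prob, rng) for _ in range(batch_size)]


def is_acyclic(edges: Set[Edge], n_nodes: int) -> bool:
    """Kahn's algorithm: the graph is a DAG iff every node can be popped."""
    indegree = [0] * n_nodes
    children: List[List[int]] = [[] for _ in range(n_nodes)]
    for u, v in edges:
        children[u].append(v)
        indegree[v] += 1
    queue = [n for n in range(n_nodes) if indegree[n] == 0]
    visited = 0
    while queue:
        node = queue.pop()
        visited += 1
        for c in children[node]:
            indegree[c] -= 1
            if indegree[c] == 0:
                queue.append(c)
    return visited == n_nodes


batch = random_dag_batch(batch_size=4, n_nodes=5, seed=0)
```

With `edge_prob=1.0` the sampler returns a complete DAG with n_nodes * (n_nodes - 1) / 2 edges, which is a quick sanity check.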

make_states_class()

Creates a GraphStates class for this environment.

Return type:

type[gfn.states.GraphStates]

n_actions
n_nodes
abstract reward(final_states)

The environment’s reward for a batch of final states.

Parameters:

final_states (gfn.states.GraphStates) – A batch of final states.

Returns:

A tensor of shape (batch_size,) containing the rewards.

Return type:

torch.Tensor

step(states, actions)

Performs a step in the environment.

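Parameters:
  • states (gfn.states.GraphStates) – A batch of current states.

  • actions (gfn.actions.GraphActions) – The actions to apply.

Returns: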

The next states.

Return type:

gfn.states.GraphStates
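A forward step either adds one edge or applies EXIT, which marks the state as terminal. The sketch below shows this for a single state using a hypothetical `DAGState` dataclass and `step_sketch` helper, with `None` standing in for EXIT. It only rejects self-loops and duplicate edges; acyclicity is assumed to be checked beforehand (as `is_action_valid` does in the real environment). It is not the library's batched implementation.

```python
from dataclasses import dataclass
from typing import FrozenSet, Optional, Tuple

Edge = Tuple[int, int]


@dataclass(frozen=True)
class DAGState:
    """Hypothetical single-graph state: its edges and whether EXIT was taken."""
    edges: FrozenSet[Edge] = frozenset()
    done: bool = False


def step_sketch(state: DAGState, action: Optional[Edge]) -> DAGState:
    """Apply one forward action; None stands for EXIT (hypothetical encoding)."""
    if state.done:
        raise ValueError("cannot act on a terminated state")
    if action is None:
        # EXIT: keep the graph, mark the trajectory as finished.
        return DAGState(state.edges, done=True)
    if action in state.edges or action[0] == action[1]:
        raise ValueError(f"invalid edge {action}")
    # Acyclicity is assumed to have been checked before calling this helper.
    return DAGState(state.edges | {action}, done=False)


s0 = DAGState()
s1 = step_sketch(s0, (0, 1))
s2 = step_sketch(s1, None)  # EXIT
```

Returning a new frozen state on every step mirrors the documented contract of returning the next states rather than mutating the inputs.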