gfn.gym.helpers.bayesian_structure.jsd

The code is adapted from: https://github.com/GFNOrg/GFN_vs_HVI/blob/master/dags/dag_gflownet/utils/exhaustive.py

Attributes

NUM_DAGS

Classes

FullPosterior

GraphCollection

Functions

all_dags(env, num_variables[, nodelist])

construct_state_dag_with_bfs(gflownet_cache, nodelist)

Constructs the state-action space of the GFlowNet.

get_children(graph, gfn_cache, nodelist)

Gets all the children of a graph.

get_full_posterior(scorer, data, env, nodelist[, verbose])

get_gflownet_cache(env, estimator, nodelist[, batch_size])

Cache the results of the GFlowNet for all the states.

get_gfn_exact_posterior(gfn_state_graph[, verbose])

get_markov_blanket(graph, node)

get_markov_blanket_graph(graph)

Build an undirected graph where two nodes are connected if one node is in the Markov blanket of another.

get_valid_actions(graph)

Gets the list of valid actions.

jensen_shannon_divergence(full_posterior, posterior)

nx_to_geometric_data(graph, env, nodelist)

posterior_exact(env, estimator, nodelist[, batch_size])

push_source_flow_to_terminal_states(gfn_state_graph, ...)

Push the flow from the source state to the terminal states.

Module Contents

class gfn.gym.helpers.bayesian_structure.jsd.FullPosterior
closures: GraphCollection
graphs: GraphCollection
log_probas: numpy.ndarray
markov: GraphCollection
to_dict()
Return type:

dict

class gfn.gym.helpers.bayesian_structure.jsd.GraphCollection
append(graph)
Parameters:

graph (networkx.DiGraph)

Return type:

None

freeze()
Return type:

GraphCollection

is_frozen()
Return type:

bool

mapping
to_dict(prefix=None)
Parameters:

prefix (str | None)

Return type:

dict

gfn.gym.helpers.bayesian_structure.jsd.NUM_DAGS = [1, 1, 3, 25, 543, 29281, 3781503]
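The values in NUM_DAGS match the number of labeled DAGs on 0 to 6 nodes (OEIS A003024). Their rapid growth is why exhaustive enumeration of the posterior is only practical for very small graphs. As a sanity check, they can be reproduced with Robinson's recurrence. `count_labeled_dags` is a hypothetical helper for illustration and is not part of this module.

```python
from math import comb


def count_labeled_dags(n: int) -> int:
    """Number of labeled DAGs on n nodes, via Robinson's recurrence."""
    counts = [1]  # a(0) = 1: the empty graph
    for m in range(1, n + 1):
        total = 0
        for k in range(1, m + 1):
            # Inclusion-exclusion over a set of k source nodes (no incoming edges);
            # each of the k * (m - k) edges from sources to the rest is free.
            sign = 1 if k % 2 == 1 else -1
            total += sign * comb(m, k) * 2 ** (k * (m - k)) * counts[m - k]
        counts.append(total)
    return counts[n]


NUM_DAGS_CHECK = [count_labeled_dags(n) for n in range(7)]
print(NUM_DAGS_CHECK)  # [1, 1, 3, 25, 543, 29281, 3781503]
```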
gfn.gym.helpers.bayesian_structure.jsd.all_dags(env, num_variables, nodelist=None)
Return type:

list[torch_geometric.data.Data]
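all_dags is not documented here. One simple way to enumerate every DAG over a small node set is brute force over edge subsets, keeping only those without a directed cycle. The sketch below illustrates that idea under assumptions; it is not this module's implementation. `enumerate_dags` is hypothetical and returns networkx.DiGraph objects rather than torch_geometric.data.Data.

```python
import itertools

import networkx as nx


def enumerate_dags(nodelist):
    """Enumerate every DAG over `nodelist` by brute force (small graphs only)."""
    # Every ordered pair (u, v) with u != v is a candidate edge.
    candidate_edges = list(itertools.permutations(nodelist, 2))
    dags = []
    for mask in range(2 ** len(candidate_edges)):
        graph = nx.DiGraph()
        graph.add_nodes_from(nodelist)
        graph.add_edges_from(
            edge for i, edge in enumerate(candidate_edges) if (mask >> i) & 1
        )
        # Keep only the edge subsets that contain no directed cycle.
        if nx.is_directed_acyclic_graph(graph):
            dags.append(graph)
    return dags


print([len(enumerate_dags(list("ABCD")[:n])) for n in range(5)])  # [1, 1, 3, 25, 543]
```

The subset loop grows as 2^(n(n-1)), so this is only usable for a handful of nodes.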

gfn.gym.helpers.bayesian_structure.jsd.construct_state_dag_with_bfs(gflownet_cache, nodelist, source_graph=None)

Constructs the state-action space of the GFlowNet.

This function performs a breadth-first search over the GFlowNet state-action space, starting from the source state, to construct a networkx.DiGraph in which each node is a GFlowNet state and each edge is labeled with the action and the log-probability of taking that action. Each node is also labeled with stop_action_log_flow, the log-probability of terminating at that state.

Parameters:
  • gflownet_cache (dict[frozenset, np.ndarray]) – The cache of log-probabilities returned by the GFlowNet.

  • nodelist (list[str]) – The list of nodes.

  • source_graph (nx.DiGraph instance) – The graph representing the source state.

Returns:

  • gfn_state_graph (nx.DiGraph instance) – The GFlowNet state-action space.

  • source_graph (nx.DiGraph instance) – The graph representing the source state.

Return type:

tuple[networkx.DiGraph, networkx.DiGraph]
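The breadth-first construction described above can be sketched as follows. This is a minimal sketch under assumptions: states are keyed by the frozenset of their edges (the cache-key convention described for get_gflownet_cache), the source state is the empty graph, and a hypothetical `log_prob_fn(graph, action)` with a toy uniform policy stands in for the GFlowNet cache. `bfs_state_graph`, `valid_actions`, and `uniform_log_prob` are illustrative names, not this module's API.

```python
import collections
import math

import networkx as nx


def valid_actions(graph):
    # Absent edges (u, v) whose addition would not close a directed cycle.
    return {
        (u, v)
        for u in graph.nodes
        for v in graph.nodes
        if u != v and not graph.has_edge(u, v) and not nx.has_path(graph, v, u)
    }


def uniform_log_prob(graph, action):
    # Toy stand-in for the GFlowNet cache: uniform over valid actions plus stop.
    return -math.log(len(valid_actions(graph)) + 1)


def bfs_state_graph(log_prob_fn, nodelist):
    """Breadth-first construction of the state-action space from the empty graph."""
    source = nx.DiGraph()
    source.add_nodes_from(nodelist)
    state_graph = nx.DiGraph()
    visited = {frozenset(source.edges())}
    queue = collections.deque([source])
    while queue:
        graph = queue.popleft()
        key = frozenset(graph.edges())
        # The stop action is stored on the node, not as a self-loop edge.
        state_graph.add_node(key, stop_action_log_flow=log_prob_fn(graph, None))
        for action in valid_actions(graph):
            child = graph.copy()
            child.add_edge(*action)
            child_key = frozenset(child.edges())
            state_graph.add_edge(
                key, child_key, action=action, log_prob_action=log_prob_fn(graph, action)
            )
            if child_key not in visited:
                visited.add(child_key)
                queue.append(child)
    return state_graph, source


state_graph, source = bfs_state_graph(uniform_log_prob, ["A", "B", "C"])
print(state_graph.number_of_nodes())  # 25, matching NUM_DAGS[3]
```

Because every action adds one edge, the resulting state graph is itself a DAG, which is what later allows flow to be pushed through it in topological order.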

gfn.gym.helpers.bayesian_structure.jsd.get_children(graph, gfn_cache, nodelist)

Gets all the children of a graph.

This function returns the set of next states reachable from a given state graph, each with its corresponding log-probability. Note that the set of children includes the stop action, encoded as a None action, for which the child graph is the same as the current graph.

Parameters:
  • graph (nx.DiGraph instance) – The current graph.

  • gfn_cache (dict) – The cache of log-probabilities returned by the GFlowNet. See get_gflownet_cache for details.

  • nodelist (list) – The list of nodes; this list is required to ensure consistent encoding of nodes in the rows and columns of the adjacency matrix.

Returns:

children – The set of all the next states from the current graph, with their corresponding log-probabilities. Each child is represented as (next_graph, action, log_prob), where next_graph is a nx.DiGraph instance, action is the edge added (as a tuple of nodes), and log_prob is the log-probability of this action. Note that the “stop” action is encoded as the action None.

Return type:

set of tuples
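A minimal sketch of reading children out of a cached row of log-probabilities. It assumes two things the text does not state: the action (u, v) sits at index i * n + j, where i and j are the positions of u and v in nodelist (row-major, matching the adjacency-matrix encoding), and invalid actions are masked with -inf. The stop action is at the last index, as stated for get_gflownet_cache. `children_from_cache` is a hypothetical name.

```python
import math

import networkx as nx
import numpy as np


def children_from_cache(graph, cache, nodelist):
    """Sketch: enumerate (next_graph, action, log_prob) triples from a cached row."""
    num_nodes = len(nodelist)
    log_probs = cache[frozenset(graph.edges())]
    children = set()
    for i, source in enumerate(nodelist):
        for j, target in enumerate(nodelist):
            # Assumed row-major layout: action (source, target) -> index i * n + j.
            log_prob = log_probs[i * num_nodes + j]
            if np.isneginf(log_prob):
                continue  # assumed: invalid actions are masked with -inf
            next_graph = graph.copy()
            next_graph.add_edge(source, target)
            children.add((next_graph, (source, target), float(log_prob)))
    # The stop action (last index) leaves the graph unchanged and is encoded as None.
    children.add((graph, None, float(log_probs[-1])))
    return children


nodelist = ["A", "B"]
empty = nx.DiGraph()
empty.add_nodes_from(nodelist)
row = np.full(len(nodelist) ** 2 + 1, -np.inf)  # toy cache row for the empty graph
row[0 * 2 + 1] = math.log(0.5)   # A -> B
row[1 * 2 + 0] = math.log(0.25)  # B -> A
row[-1] = math.log(0.25)         # stop
children = children_from_cache(empty, {frozenset(): row}, nodelist)
```

networkx graphs hash by identity, so they can be stored inside the returned set of tuples.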

gfn.gym.helpers.bayesian_structure.jsd.get_full_posterior(scorer, data, env, nodelist, verbose=True)
Return type:

FullPosterior

gfn.gym.helpers.bayesian_structure.jsd.get_gflownet_cache(env, estimator, nodelist, batch_size=256)

Cache the results of the GFlowNet for all the states.

This function caches the log-probabilities for all the actions and for all the states of the GFlowNet.

Parameters:
  • env (BayesianStructure) – The Bayesian structure learning environment.

  • estimator (Estimator) – The GFlowNet policy estimator.

  • nodelist (list[str]) – List of node names to ensure consistent node encoding in adjacency matrices.

  • batch_size (int, default=256) – Batch size for processing states through the GFlowNet.

Returns:

cache – The cache of log-probabilities returned by the GFlowNet. The keys of the cache are the graphs (encoded as a frozenset of their edges), and the corresponding value is an array of shape (num_variables ** 2 + 1,) containing the log-probabilities of all the actions in that state (including the “stop” action, at the last index).

Return type:

dict of (frozenset, np.ndarray)
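The batching and cache layout described above can be sketched as follows, assuming the states are already available as a list of networkx graphs. `policy_fn` is a hypothetical stand-in for the estimator that maps a batch of adjacency matrices (rows and columns ordered by nodelist) to log-probabilities of shape (batch, n ** 2 + 1). The toy policy below ignores acyclicity and is only meant to show the shapes; `build_cache` and `toy_policy` are illustrative names.

```python
import networkx as nx
import numpy as np


def toy_policy(adjacency):
    """Hypothetical estimator stand-in: (B, n, n) adjacency -> (B, n**2 + 1) log-probs.

    Uniform over absent, non-self-loop edges plus stop; it does NOT check acyclicity.
    """
    batch, n, _ = adjacency.shape
    allowed = (adjacency == 0) & ~np.eye(n, dtype=bool)
    # Flatten row-major so edge (i, j) lands at index i * n + j; stop goes last.
    allowed = np.concatenate(
        [allowed.reshape(batch, n * n), np.ones((batch, 1), dtype=bool)], axis=1
    )
    counts = allowed.sum(axis=1, keepdims=True)
    return np.where(allowed, -np.log(counts), -np.inf)


def build_cache(states, policy_fn, nodelist, batch_size=256):
    """Run states through `policy_fn` in batches; key each row by the state's edge set."""
    cache = {}
    for start in range(0, len(states), batch_size):
        batch = states[start:start + batch_size]
        # Rows and columns follow `nodelist`, so the encoding is consistent.
        adjacency = np.stack([nx.to_numpy_array(g, nodelist=nodelist) for g in batch])
        for graph, row in zip(batch, policy_fn(adjacency)):
            cache[frozenset(graph.edges())] = row
    return cache


nodelist = ["A", "B", "C"]
empty = nx.DiGraph()
empty.add_nodes_from(nodelist)
one_edge = empty.copy()
one_edge.add_edge("A", "B")
cache = build_cache([empty, one_edge], toy_policy, nodelist, batch_size=1)
print(cache[frozenset({("A", "B")})].shape)  # (10,)
```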

gfn.gym.helpers.bayesian_structure.jsd.get_gfn_exact_posterior(gfn_state_graph, verbose=True)
Parameters:
  • gfn_state_graph (networkx.DiGraph)

  • verbose (bool)

Return type:

FullPosterior

gfn.gym.helpers.bayesian_structure.jsd.get_markov_blanket(graph, node)
Parameters:
  • graph (networkx.DiGraph)

  • node (str)

Return type:

set[str]
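get_markov_blanket is not documented here. In a DAG, the Markov blanket of a node is standardly defined as its parents, its children, and the other parents of its children. The sketch below implements that definition. It is an assumption that this module follows it exactly, and `markov_blanket` is a hypothetical name.

```python
import networkx as nx


def markov_blanket(graph, node):
    """Parents, children, and the children's other parents of `node`."""
    parents = set(graph.predecessors(node))
    children = set(graph.successors(node))
    # Co-parents ("spouses"): every other parent of each child.
    spouses = {p for child in children for p in graph.predecessors(child)}
    return (parents | children | spouses) - {node}


g = nx.DiGraph([("A", "C"), ("B", "C"), ("C", "D")])
print(sorted(markov_blanket(g, "A")))  # ['B', 'C']
```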

gfn.gym.helpers.bayesian_structure.jsd.get_markov_blanket_graph(graph)

Build an undirected graph where two nodes are connected if one node is in the Markov blanket of another.

Parameters:

graph (networkx.DiGraph)

Return type:

networkx.DiGraph
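A minimal sketch of the construction described above, connecting each node to every member of its Markov blanket. The result coincides with the moral graph of the DAG: original edges made undirected, plus an edge between every pair of co-parents. The sketch returns an undirected nx.Graph, which differs from the documented return type. `markov_blanket_graph` is a hypothetical name.

```python
import networkx as nx


def markov_blanket_graph(graph):
    """Undirected graph linking each node to every member of its Markov blanket."""
    mb_graph = nx.Graph()
    mb_graph.add_nodes_from(graph.nodes)
    for node in graph.nodes:
        parents = set(graph.predecessors(node))
        children = set(graph.successors(node))
        spouses = {p for c in children for p in graph.predecessors(c)}
        for other in (parents | children | spouses) - {node}:
            mb_graph.add_edge(node, other)  # undirected, so duplicates collapse
    return mb_graph


g = nx.DiGraph([("A", "C"), ("B", "C"), ("C", "D")])
mb = markov_blanket_graph(g)
print(sorted(tuple(sorted(e)) for e in mb.edges()))
# [('A', 'B'), ('A', 'C'), ('B', 'C'), ('C', 'D')]
```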

gfn.gym.helpers.bayesian_structure.jsd.get_valid_actions(graph)

Gets the list of valid actions.

The valid actions correspond to directed edges that can be added to the current graph, such that adding any of those edges would still yield a DAG. In other words, those are edges that (1) are not already present in the graph, and (2) would not introduce a directed cycle.

Parameters:

graph (nx.DiGraph instance) – The current graph.

Returns:

edges – A set of directed edges, encoded as a tuple of nodes from graph, corresponding to the valid actions in the state graph.

Return type:

set of tuples
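The two conditions above translate directly into code. Adding u -> v creates a cycle exactly when a directed path v -> ... -> u already exists. This is a minimal sketch using networkx; `valid_actions` is a hypothetical name, and the module's own implementation may check acyclicity differently.

```python
import itertools

import networkx as nx


def valid_actions(graph):
    """Directed edges (u, v) that can be added while keeping `graph` acyclic."""
    actions = set()
    # permutations() skips self-loops, which would be cycles anyway.
    for u, v in itertools.permutations(graph.nodes, 2):
        # (1) the edge is not already present; (2) no path v -> u exists,
        # otherwise adding u -> v would close a directed cycle.
        if not graph.has_edge(u, v) and not nx.has_path(graph, v, u):
            actions.add((u, v))
    return actions


g = nx.DiGraph([("A", "B"), ("B", "C")])
print(sorted(valid_actions(g)))  # [('A', 'C')]
```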

gfn.gym.helpers.bayesian_structure.jsd.jensen_shannon_divergence(full_posterior, posterior)
Return type:

float
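The Jensen-Shannon divergence is JSD(P, Q) = ½ KL(P ‖ M) + ½ KL(Q ‖ M), with M = (P + Q) / 2. The sketch below computes it from two log-probability arrays. It assumes both arrays are aligned over the same list of graphs (for example, log_probas-style arrays) and uses natural logarithms, so the result lies in [0, ln 2]. How this module aligns full_posterior and posterior is not documented here, and `jsd_from_log_probs` is a hypothetical name.

```python
import numpy as np


def jsd_from_log_probs(log_p, log_q):
    """Jensen-Shannon divergence (in nats) between two aligned log-probability arrays."""
    log_p = np.asarray(log_p, dtype=float)
    log_q = np.asarray(log_q, dtype=float)
    # log M, where M = (P + Q) / 2, computed without leaving log space.
    log_m = np.logaddexp(log_p, log_q) - np.log(2.0)
    p, q = np.exp(log_p), np.exp(log_q)
    with np.errstate(invalid="ignore"):
        # Zero-probability entries contribute nothing (0 * log 0 := 0).
        kl_pm = np.sum(np.where(p > 0, p * (log_p - log_m), 0.0))
        kl_qm = np.sum(np.where(q > 0, q * (log_q - log_m), 0.0))
    return float(0.5 * (kl_pm + kl_qm))


# Disjoint supports give the maximum value, ln 2.
print(jsd_from_log_probs(np.array([0.0, -np.inf]), np.array([-np.inf, 0.0])))
```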

gfn.gym.helpers.bayesian_structure.jsd.nx_to_geometric_data(graph, env, nodelist)
Return type:

torch_geometric.data.Data

gfn.gym.helpers.bayesian_structure.jsd.posterior_exact(env, estimator, nodelist, batch_size=256)
Return type:

networkx.DiGraph

gfn.gym.helpers.bayesian_structure.jsd.push_source_flow_to_terminal_states(gfn_state_graph, source_state_graph)

Push the flow from the source state to the terminal states.

This function traverses the GFlowNet state-action space graph (DAG) in a topologically sorted order and “pushes” the log_flow from each node to its children according to the log_prob_action specified on the edges. The topological sort ensures that all the flow has “arrived” at a node before “moving” its flow to its children.

Parameters:
  • gfn_state_graph (nx.DiGraph instance) – The GFlowNet state-action space where each node represents one GFlowNet state and each edge represents one GFlowNet action.

  • source_state_graph (nx.DiGraph instance) – The graph representing the source state.

Returns:

gfn_state_graph – The GFlowNet state-action space but now each node has an attribute named log_flow, which is -np.inf for non-terminal states and the marginal log probability for the terminal states.

Return type:

nx.DiGraph instance
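The topological flow-pushing step can be sketched on a toy state graph. The sketch assumes the stop action is stored on each node as stop_action_log_flow and edges carry log_prob_action, as described for construct_state_dag_with_bfs. Unlike the documented function, it keeps the propagated flow in log_flow and writes the probability of terminating at each state to a hypothetical terminal_log_prob attribute, so the two quantities stay separate. `push_flow` is an illustrative name.

```python
import math

import networkx as nx
import numpy as np


def push_flow(state_graph, source):
    """Propagate log-flow from `source` in topological order (sketch)."""
    for node in state_graph.nodes:
        state_graph.nodes[node]["log_flow"] = -np.inf
    state_graph.nodes[source]["log_flow"] = 0.0  # all probability mass starts here
    for node in nx.topological_sort(state_graph):
        # Every predecessor has already pushed its flow into `node`.
        log_flow = state_graph.nodes[node]["log_flow"]
        for child in state_graph.successors(node):
            edge_log_prob = state_graph.edges[node, child]["log_prob_action"]
            state_graph.nodes[child]["log_flow"] = np.logaddexp(
                state_graph.nodes[child]["log_flow"], log_flow + edge_log_prob
            )
        # Hypothetical attribute: marginal log-probability of stopping here.
        state_graph.nodes[node]["terminal_log_prob"] = (
            log_flow + state_graph.nodes[node]["stop_action_log_flow"]
        )
    return state_graph


toy = nx.DiGraph()
toy.add_node("empty", stop_action_log_flow=math.log(0.5))
toy.add_node("A->B", stop_action_log_flow=0.0)
toy.add_node("B->A", stop_action_log_flow=0.0)
toy.add_edge("empty", "A->B", log_prob_action=math.log(0.25))
toy.add_edge("empty", "B->A", log_prob_action=math.log(0.25))
push_flow(toy, "empty")
print({n: round(math.exp(d["terminal_log_prob"]), 2) for n, d in toy.nodes(data=True)})
# {'empty': 0.5, 'A->B': 0.25, 'B->A': 0.25}
```

Working in log space with logaddexp avoids underflow when many low-probability paths reach the same state.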