gfn.gym.helpers.bayesian_structure.jsd
======================================

.. py:module:: gfn.gym.helpers.bayesian_structure.jsd

.. autoapi-nested-parse::

   The code is adapted from:
   https://github.com/GFNOrg/GFN_vs_HVI/blob/master/dags/dag_gflownet/utils/exhaustive.py


Attributes
----------

.. autoapisummary::

   gfn.gym.helpers.bayesian_structure.jsd.NUM_DAGS


Classes
-------

.. autoapisummary::

   gfn.gym.helpers.bayesian_structure.jsd.FullPosterior
   gfn.gym.helpers.bayesian_structure.jsd.GraphCollection


Functions
---------

.. autoapisummary::

   gfn.gym.helpers.bayesian_structure.jsd.all_dags
   gfn.gym.helpers.bayesian_structure.jsd.construct_state_dag_with_bfs
   gfn.gym.helpers.bayesian_structure.jsd.get_children
   gfn.gym.helpers.bayesian_structure.jsd.get_full_posterior
   gfn.gym.helpers.bayesian_structure.jsd.get_gflownet_cache
   gfn.gym.helpers.bayesian_structure.jsd.get_gfn_exact_posterior
   gfn.gym.helpers.bayesian_structure.jsd.get_markov_blanket
   gfn.gym.helpers.bayesian_structure.jsd.get_markov_blanket_graph
   gfn.gym.helpers.bayesian_structure.jsd.get_valid_actions
   gfn.gym.helpers.bayesian_structure.jsd.jensen_shannon_divergence
   gfn.gym.helpers.bayesian_structure.jsd.nx_to_geometric_data
   gfn.gym.helpers.bayesian_structure.jsd.posterior_exact
   gfn.gym.helpers.bayesian_structure.jsd.push_source_flow_to_terminal_states


Module Contents
---------------

.. py:class:: FullPosterior

   .. py:attribute:: closures
      :type: GraphCollection


   .. py:attribute:: graphs
      :type: GraphCollection


   .. py:attribute:: log_probas
      :type: numpy.ndarray


   .. py:attribute:: markov
      :type: GraphCollection


   .. py:method:: to_dict()


.. py:class:: GraphCollection

   .. py:method:: append(graph)


   .. py:method:: freeze()


   .. py:method:: is_frozen()


   .. py:attribute:: mapping


   .. py:method:: to_dict(prefix = None)


.. py:data:: NUM_DAGS
   :value: [1, 1, 3, 25, 543, 29281, 3781503]


.. py:function:: all_dags(env, num_variables, nodelist = None)

.. py:function:: construct_state_dag_with_bfs(gflownet_cache, nodelist, source_graph = None)

   Constructs the state-action space of the GFlowNet.

   This function performs Breadth-First Search on the GFlowNet state-action
   space starting from the source state, in order to construct a
   networkx.DiGraph object where each node is a GFlowNet state and each edge
   is labeled with the action and the log probability of taking that action.
   Each node is also labeled with the stop_action_log_flow which contains the
   probability of terminating at that state.

   :param gflownet_cache: The cache of log-probabilities returned by the GFlowNet.
   :type gflownet_cache: dict[frozenset, np.ndarray]
   :param nodelist: The list of nodes.
   :type nodelist: list[str]
   :param source_graph: The graph representing the source state.
   :type source_graph: nx.DiGraph instance

   :returns: * **gfn_state_graph** (*nx.DiGraph instance*) -- The GFlowNet state-action space.
             * **source_graph** (*nx.DiGraph instance*) -- The graph representing the source state.


.. py:function:: get_children(graph, gfn_cache, nodelist)

   Gets all the children of a graph.

   This function returns the set of next states reachable from a particular
   state `graph`, each with its corresponding log-probability. Note that the
   set of children includes the stop action, encoded as a `None` action, for
   which the child graph is the same as the current graph.

   :param graph: The current graph.
   :type graph: nx.DiGraph instance
   :param gfn_cache: The cache of log-probabilities returned by the GFlowNet. See
                     `dag_gflownet.utils.gflownet.get_gflownet_cache` for details.
   :type gfn_cache: dict
   :param nodelist: The list of nodes; this list is required to ensure consistent
                    encoding of nodes in the rows and columns of the adjacency matrix.
   :type nodelist: list

   :returns: **children** -- The set of all the next states from the current graph, with
             their corresponding log-probability. Each child is represented as
             `(next_graph, action, log_prob)`, where `next_graph` is a nx.DiGraph
             instance, `action` is the edge added (as a tuple of nodes), and
             `log_prob` is the log-probability of this action. Note that the "stop"
             action is encoded as the action `None`.
   :rtype: set of tuples


.. py:function:: get_full_posterior(scorer, data, env, nodelist, verbose = True)

.. py:function:: get_gflownet_cache(env, estimator, nodelist, batch_size = 256)

   Cache the results of the GFlowNet for all the states.

   This function caches the log-probabilities for all the actions and for all
   the states of the GFlowNet.

   :param env: The Bayesian structure learning environment.
   :type env: BayesianStructure
   :param estimator: The GFlowNet policy estimator.
   :type estimator: Estimator
   :param nodelist: List of node names to ensure consistent node encoding in adjacency matrices.
   :type nodelist: list[str]
   :param batch_size: Batch size for processing states through the GFlowNet.
   :type batch_size: int, default=256

   :returns: **cache** -- The cache of log-probabilities returned by the GFlowNet. The keys of
             the cache are the graphs (encoded as a frozenset of their edges), and the
             corresponding value is an array of size `(num_variables ** 2 + 1,)`
             containing the log-probabilities of all the actions in that state
             (including the "stop" action, at the last index).
   :rtype: dict of (frozenset, np.ndarray)


.. py:function:: get_gfn_exact_posterior(gfn_state_graph, verbose = True)

.. py:function:: get_markov_blanket(graph, node)

.. py:function:: get_markov_blanket_graph(graph)

   Build an undirected graph where two nodes are connected if
   one node is in the Markov blanket of another.


.. py:function:: get_valid_actions(graph)

   Gets the list of valid actions.

   The valid actions correspond to directed edges that can be added to the
   current graph, such that adding any of those edges would still yield a
   DAG. In other words, those are edges that (1) are not already present in
   the graph, and (2) would not introduce a directed cycle.

   :param graph: The current graph.
   :type graph: nx.DiGraph instance

   :returns: **edges** -- A set of directed edges, encoded as a tuple of nodes from `graph`,
             corresponding to the valid actions in the state `graph`.
   :rtype: set of tuples


.. py:function:: jensen_shannon_divergence(full_posterior, posterior)

.. py:function:: nx_to_geometric_data(graph, env, nodelist)

.. py:function:: posterior_exact(env, estimator, nodelist, batch_size = 256)

.. py:function:: push_source_flow_to_terminal_states(gfn_state_graph, source_state_graph)

   Pushes the flow from the source state to the terminal states.

   This function traverses the GFlowNet state-action space graph (DAG) in a
   topologically sorted order and "pushes" the log_flow from each node to its
   children according to the log_prob_action specified on the edges. The
   topological sort ensures that all the flow has "arrived" at a node before
   "moving" its flow to its children.

   :param gfn_state_graph: The GFlowNet state-action space where each node represents one
                           GFlowNet state and each edge represents one GFlowNet action.
   :type gfn_state_graph: nx.DiGraph instance
   :param source_state_graph: The graph representing the source state.
   :type source_state_graph: nx.DiGraph instance

   :returns: **gfn_state_graph** -- The GFlowNet state-action space but now each node has an
             attribute named log_flow, which is -np.inf for non-terminal states and
             the marginal log probability for the terminal states.
   :rtype: nx.DiGraph instance
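
The flow-pushing traversal described for ``push_source_flow_to_terminal_states`` can be illustrated with a small dependency-free sketch. This is not the module's implementation. It replaces the ``nx.DiGraph`` with plain dictionaries, and the names ``push_flow``, ``logaddexp``, ``edges`` and ``stop_log_probs`` are hypothetical. As a simplification, the termination log-probability is returned per state instead of being stored on separate terminal nodes.

```python
import math
from graphlib import TopologicalSorter  # standard library, Python 3.9+


def logaddexp(a, b):
    """Stable log(exp(a) + exp(b)) that also accepts -inf (hypothetical helper)."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    hi, lo = max(a, b), min(a, b)
    return hi + math.log1p(math.exp(lo - hi))


def push_flow(edges, stop_log_probs, source):
    """Push log-flow from `source` through a DAG of states.

    edges: {state: {child: log-probability of the action leading to child}}
    stop_log_probs: {state: log-probability of stopping in that state}
    Returns {state: marginal log-probability of terminating in that state}.
    """
    # TopologicalSorter expects a mapping from each node to its predecessors.
    preds = {state: set() for state in stop_log_probs}
    for parent, children in edges.items():
        for child in children:
            preds[child].add(parent)

    log_flow = {state: -math.inf for state in stop_log_probs}
    log_flow[source] = 0.0  # all of the flow starts at the source state

    terminal = {}
    for state in TopologicalSorter(preds).static_order():
        # Every parent has been visited, so log_flow[state] is complete here.
        flow = log_flow[state]
        terminal[state] = flow + stop_log_probs[state]
        for child, log_prob in edges.get(state, {}).items():
            log_flow[child] = logaddexp(log_flow[child], flow + log_prob)
    return terminal


# Toy state space over two possible edges "a" and "b" (assumed probabilities).
edges = {
    "s0": {"a": math.log(0.5), "b": math.log(0.25)},
    "a": {"ab": math.log(0.5)},
    "b": {"ab": math.log(0.5)},
}
stop_log_probs = {
    "s0": math.log(0.25),
    "a": math.log(0.5),
    "b": math.log(0.5),
    "ab": 0.0,
}
terminal = push_flow(edges, stop_log_probs, "s0")
```

Because each state forwards its flow only after all of its parents have been processed, the exponentiated values in ``terminal`` sum to one whenever the per-state action probabilities (including the stop action) sum to one.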
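
The two validity conditions stated for ``get_valid_actions`` (the edge is not already present, and adding it does not create a directed cycle) can be checked as in the following sketch. It is an illustration under assumptions, not the module's code: graphs are represented as a set of edge tuples rather than an ``nx.DiGraph``, and ``has_path`` and ``valid_actions`` are hypothetical names.

```python
def has_path(edges, start, goal):
    """Return True if `goal` is reachable from `start` along directed edges."""
    stack, seen = [start], set()
    while stack:
        node = stack.pop()
        if node == goal:
            return True
        if node in seen:
            continue
        seen.add(node)
        stack.extend(v for (u, v) in edges if u == node)
    return False


def valid_actions(nodes, edges):
    """Edges (u, v) that can be added while keeping the graph acyclic.

    Adding u -> v creates a cycle exactly when v already reaches u.
    """
    return {
        (u, v)
        for u in nodes
        for v in nodes
        if u != v and (u, v) not in edges and not has_path(edges, v, u)
    }


nodes = ["A", "B", "C"]
edges = {("A", "B")}
actions = valid_actions(nodes, edges)
```

With the single edge ``A -> B`` present, the reverse edge ``B -> A`` is rejected because it would close a cycle, while every edge involving ``C`` remains valid.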