gfn.gym.graph_building
======================

.. py:module:: gfn.gym.graph_building


Classes
-------

.. autoapisummary::

   gfn.gym.graph_building.GraphBuilding
   gfn.gym.graph_building.GraphBuildingOnEdges


Module Contents
---------------

.. py:class:: GraphBuilding(num_node_classes, num_edge_classes, state_evaluator, is_directed = True, max_nodes = None, device = 'cpu', s0 = None, sf = None, debug = False)

   Bases: :py:obj:`gfn.env.GraphEnv`

   Environment for incrementally building graphs.

   This environment allows constructing graphs by:

   - Adding nodes of a given class
   - Adding edges of a given class between existing nodes
   - Terminating construction (EXIT)

   .. attribute:: num_node_classes

      The number of node classes.

   .. attribute:: num_edge_classes

      The number of edge classes.

   .. attribute:: state_evaluator

      A callable that computes rewards for final states.

   .. attribute:: is_directed

      Whether the graph is directed.

   .. py:method:: backward_step(states, actions)

      Performs a backward step in the environment.

      :param states: The current states.
      :param actions: The actions to undo.

      :returns: The previous states.

   .. py:method:: is_action_valid(states, actions, backward = False)

      Checks whether the actions are valid for the given states.

      :param states: Current graph states.
      :param actions: Actions to validate.
      :param backward: Whether this is a backward step.

      :returns: True if all actions are valid, False otherwise.

   .. py:method:: make_actions_class()

      Returns the ``GraphActions`` class for this environment.

      :returns: A subclass of ``GraphActions`` (the class itself, not an instance) with environment-specific functionality.

   .. py:method:: make_random_states(batch_shape, conditions = None, device = None, debug = False)

      Generates random states.

      :param batch_shape: The shape of the batch.
      :param conditions: Optional tensor of shape ``(*batch_shape, condition_dim)`` containing condition vectors for conditional GFlowNets.
      :param device: The device to use.
      :param debug: If True, emit States with debug guards (not compile-friendly).

      :returns: A `GraphStates` object with random states.

   .. py:method:: make_states_class()

      Creates a `GraphStates` class for this environment.

   .. py:attribute:: max_nodes
      :value: None

   .. py:method:: reward(final_states)

      Computes the environment's reward for a batch of final states.

      :param final_states: A batch of final states.

      :returns: A tensor of shape `(batch_size,)` containing the rewards.

   .. py:attribute:: state_evaluator

   .. py:method:: step(states, actions)

      Performs a step in the environment.

      :param states: The current states.
      :param actions: The actions to take.

      :returns: The next states.

.. py:class:: GraphBuildingOnEdges(n_nodes, state_evaluator, directed, device, debug = False)

   Bases: :py:obj:`GraphBuilding`

   Environment for building graphs edge by edge with a discrete action space.

   The environment supports both directed and undirected graphs. In each state, the policy can:

   1. Add an edge between existing nodes.
   2. Use the exit action to terminate graph building.

   The action space is discrete, with size:

   - For directed graphs: `n_nodes^2 - n_nodes + 1` (every ordered pair of distinct nodes + exit).
   - For undirected graphs: `(n_nodes^2 - n_nodes)/2 + 1` (upper triangle of the adjacency matrix + exit).

   .. attribute:: n_nodes

      The number of nodes in the graph.

      :type: int

   .. attribute:: n_possible_edges

      The number of possible edges.

      :type: int

   .. py:method:: is_action_valid(states, actions, backward = False)

      Checks whether the actions are valid.

      :param states: The current states.
      :param actions: The actions to validate.
      :param backward: Whether the actions are backward actions.

      :returns: `True` if the actions are valid, `False` otherwise.

   .. py:method:: make_random_states(batch_shape, conditions = None, device = None, debug = False)

      Makes a batch of random graph states with a fixed number of nodes.

      :param batch_shape: Shape of the batch dimensions.
      :param conditions: Optional tensor of shape ``(*batch_shape, condition_dim)`` containing condition vectors for conditional GFlowNets.
      :param device: The device to use.
      :param debug: If True, emit States with debug guards (not compile-friendly).

      :returns: A `GraphStates` object containing random graph states.

   .. py:method:: make_states_class()

      Creates a `GraphStates` class for this environment.

   .. py:attribute:: n_nodes
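
The action-space sizes listed for ``GraphBuildingOnEdges`` can be checked with a short, self-contained sketch. It also shows a simplified analogue of forward and backward validity masks. This is not the library's implementation. The helper names (``edge_list``, ``num_actions``, ``decode_action``, ``forward_mask``, ``backward_mask``) are hypothetical. The following details are also assumptions made for illustration: the edge ordering (lexicographic over ordered pairs for directed graphs, over the upper triangle for undirected ones), the exit action occupying the last index, and the masking rules.

.. code-block:: python

   from itertools import combinations, permutations
   from typing import List, Sequence, Tuple, Union

   Edge = Tuple[int, int]


   def edge_list(n_nodes: int, directed: bool) -> List[Edge]:
       """Enumerate candidate edges in a fixed order (an assumed ordering)."""
       if directed:
           # Every ordered pair (i, j) with i != j: n_nodes * (n_nodes - 1) edges.
           return list(permutations(range(n_nodes), 2))
       # Upper triangle only: pairs with i < j, (n_nodes^2 - n_nodes) / 2 edges.
       return list(combinations(range(n_nodes), 2))


   def num_actions(n_nodes: int, directed: bool) -> int:
       """All possible edges plus one exit action."""
       return len(edge_list(n_nodes, directed)) + 1


   def decode_action(index: int, n_nodes: int, directed: bool) -> Union[Edge, str]:
       """Map a flat action index to an edge, or to "exit" for the last index (assumed)."""
       edges = edge_list(n_nodes, directed)
       if index == len(edges):
           return "exit"
       if not 0 <= index < len(edges):
           raise ValueError(f"action index {index} out of range")
       return edges[index]


   def _normalize(edge: Edge, directed: bool) -> Edge:
       # For undirected graphs, (j, i) and (i, j) denote the same edge.
       i, j = edge
       return edge if directed else (min(i, j), max(i, j))


   def forward_mask(existing: Sequence[Edge], n_nodes: int, directed: bool) -> List[bool]:
       """Forward (hypothetical rule): add an edge only if absent; exit always allowed."""
       present = {_normalize(e, directed) for e in existing}
       mask = [e not in present for e in edge_list(n_nodes, directed)]
       return mask + [True]


   def backward_mask(existing: Sequence[Edge], n_nodes: int, directed: bool) -> List[bool]:
       """Backward (hypothetical rule): remove only existing edges; exit is never valid."""
       present = {_normalize(e, directed) for e in existing}
       mask = [e in present for e in edge_list(n_nodes, directed)]
       return mask + [False]


   # Usage with three nodes.
   print(num_actions(3, directed=True))   # 7, i.e. 3^2 - 3 + 1
   print(num_actions(3, directed=False))  # 4, i.e. (3^2 - 3)/2 + 1
   print(decode_action(2, 3, directed=False))  # (1, 2)
   print(forward_mask([(1, 0)], 3, directed=False))  # [False, True, True, True]

Enumerating only the upper triangle for undirected graphs is what halves the edge count, because ``(i, j)`` and ``(j, i)`` would otherwise name the same edge. Excluding ``i == j`` is why the ``n_nodes`` self-loops are subtracted in both formulas.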