gfn.gym.helpers.bayesian_structure.graph
Functions
| sample_erdos_renyi_graph | Sample an Erdos-Renyi graph. |
| sample_erdos_renyi_linear_gaussian | Sample a linear Gaussian Bayesian network based on an Erdos-Renyi graph. |
Module Contents
- gfn.gym.helpers.bayesian_structure.graph.sample_erdos_renyi_graph(num_nodes, rng, p=None, num_edges=None, node_names=None)
Sample an Erdos-Renyi graph.
- Parameters:
num_nodes (int) – Number of nodes in the graph.
rng (numpy.random.Generator) – Numpy random number generator.
p (Optional[float]) – Probability of creating an edge.
num_edges (Optional[int]) – Total number of edges (used to compute p if p is None).
node_names (Optional[List[str]]) – Optional list of node names.
- Returns:
- A PyTorch Geometric Data object representing the sampled graph, with an attribute ‘node_names’ mapping node indices to node names.
- Return type:
Data
- Raises:
ValueError – If both p and num_edges are None.
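For intuition, the sampling step can be sketched with NumPy alone. This is a minimal sketch, not the library's implementation, and it rests on assumptions this page does not confirm: that the graph must be a DAG (as a Bayesian network structure requires), that `num_edges` is turned into an edge probability via `p = num_edges / (num_nodes * (num_nodes - 1) / 2)` (so it is an expected rather than exact edge count), that edges are drawn only in the strictly lower triangle of the adjacency matrix to rule out cycles, and that node labels are then randomly permuted. The helper name `sample_er_dag_edge_index` is hypothetical; instead of a `Data` object it returns a `(2, E)` array in the same source/target layout as PyTorch Geometric's `edge_index`.

```python
import numpy as np


def sample_er_dag_edge_index(num_nodes, rng, p=None, num_edges=None):
    """Hypothetical sketch: sample a random DAG and return a (2, E) edge index."""
    if p is None:
        if num_edges is None:
            raise ValueError("One of p or num_edges must be specified.")
        # Assumed convention: num_edges is the expected number of edges
        # among the n * (n - 1) / 2 possible edges of a DAG.
        p = num_edges / (num_nodes * (num_nodes - 1) / 2)

    # Draw each candidate edge independently with probability p.
    adjacency = rng.binomial(1, p, size=(num_nodes, num_nodes))
    # Keep only the strictly lower triangle: every edge then goes from a
    # higher index to a lower one, so no cycle can form.
    adjacency = np.tril(adjacency, k=-1)

    # Randomly relabel the nodes so the topological order is not fixed.
    perm = rng.permutation(num_nodes)
    adjacency = adjacency[perm][:, perm]

    # adjacency[i, j] == 1 means edge i -> j; stack as (sources, targets).
    sources, targets = np.nonzero(adjacency)
    return np.stack([sources, targets])


rng = np.random.default_rng(0)
edge_index = sample_er_dag_edge_index(5, rng, num_edges=4)
print(edge_index.shape)  # first dimension is 2; the edge count varies with the seed
```

Sampling in a fixed triangle and then permuting is a simple way to get acyclicity for free, at the cost of never producing cycles even when a caller might want a general (cyclic) Erdos-Renyi graph.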
- gfn.gym.helpers.bayesian_structure.graph.sample_erdos_renyi_linear_gaussian(num_nodes, rng, p=None, num_edges=None, node_names=None, loc_edges=0.0, scale_edges=1.0, obs_noise=0.1)
Sample a linear Gaussian Bayesian network based on an Erdos-Renyi graph.
Creates the graph structure using PyTorch Geometric and assigns a conditional probability distribution (CPD) factor to each node. Each factor is constructed by sampling a parameter vector theta, with the bias fixed to zero and one coefficient for each parent in the node's parent set, as determined by the graph structure.
- Parameters:
num_nodes (int) – Number of nodes.
rng (numpy.random.Generator) – Random number generator for reproducibility.
p (Optional[float]) – Probability of creating an edge.
num_edges (Optional[int]) – Total number of edges (used to compute p if p is None).
node_names (Optional[List[str]]) – Optional list of node names.
loc_edges (float) – Mean value for edge parameters.
scale_edges (float) – Standard deviation for edge parameters.
obs_noise (float) – Observation noise for each node.
- Returns:
- A PyTorch Geometric Data object with additional attributes ‘nodes’ and ‘cpds’.
‘nodes’ is a list of node names; ‘cpds’ is a list with one LinearGaussianCPD factor per node.
- Return type:
Data
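The CPD construction described above, and the generative process it implies, can be sketched in NumPy. This is a hedged illustration, not the library's code. Assumptions: the graph is given as a dense 0/1 adjacency matrix (row = parent, column = child); each node's theta stores the bias first and then one coefficient per parent, with coefficients drawn from Normal(loc_edges, scale_edges); and obs_noise is treated as the noise variance (it could equally be a standard deviation in the real implementation). The helpers `sample_linear_gaussian_params`, `topological_order`, and `ancestral_sample` are hypothetical, and they return plain arrays rather than LinearGaussianCPD factors.

```python
import numpy as np


def sample_linear_gaussian_params(adjacency, rng, loc_edges=0.0, scale_edges=1.0):
    """Hypothetical sketch: one theta vector per node, bias first and fixed to 0."""
    params = []
    for node in range(adjacency.shape[0]):
        # Parents of `node` are the rows with an edge into its column.
        parents = np.flatnonzero(adjacency[:, node])
        theta = rng.normal(loc_edges, scale_edges, size=len(parents) + 1)
        theta[0] = 0.0  # bias fixed to zero, as the docstring states
        params.append((parents, theta))
    return params


def topological_order(adjacency):
    """Kahn's algorithm on a dense 0/1 adjacency matrix (row -> column)."""
    in_degree = adjacency.sum(axis=0).astype(int)
    ready = [i for i in range(len(in_degree)) if in_degree[i] == 0]
    order = []
    while ready:
        node = ready.pop()
        order.append(node)
        for child in np.flatnonzero(adjacency[node]):
            in_degree[child] -= 1
            if in_degree[child] == 0:
                ready.append(child)
    if len(order) != len(in_degree):
        raise ValueError("graph has a cycle")
    return order


def ancestral_sample(adjacency, params, num_samples, rng, obs_noise=0.1):
    """Draw x_i = theta_0 + sum_j theta_j * x_{parent_j} + noise, parents first."""
    data = np.zeros((num_samples, adjacency.shape[0]))
    for node in topological_order(adjacency):
        parents, theta = params[node]
        mean = theta[0] + data[:, parents] @ theta[1:]
        # Assumption: obs_noise is a variance, so the std is its square root.
        data[:, node] = mean + rng.normal(0.0, np.sqrt(obs_noise), size=num_samples)
    return data


rng = np.random.default_rng(0)
# Chain 0 -> 1 -> 2 as a dense adjacency matrix.
adjacency = np.array([[0, 1, 0], [0, 0, 1], [0, 0, 0]])
params = sample_linear_gaussian_params(adjacency, rng)
samples = ancestral_sample(adjacency, params, num_samples=1000, rng=rng)
print(samples.shape)  # (1000, 3)
```

Sampling nodes in topological order guarantees that every parent's values exist before its children are drawn, which is what makes the per-node linear Gaussian factors define a single joint distribution.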