gfn.gym.helpers.bayesian_structure.evaluation
=============================================

.. py:module:: gfn.gym.helpers.bayesian_structure.evaluation

.. autoapi-nested-parse::

   The code is adapted from: https://github.com/larslorch/dibs/blob/master/dibs/metrics.py


Functions
---------

.. autoapisummary::

   gfn.gym.helpers.bayesian_structure.evaluation.expected_edges
   gfn.gym.helpers.bayesian_structure.evaluation.expected_shd
   gfn.gym.helpers.bayesian_structure.evaluation.posterior_estimate
   gfn.gym.helpers.bayesian_structure.evaluation.threshold_metrics


Module Contents
---------------

.. py:function:: expected_edges(posterior_samples)

   Compute the expected number of edges.

   This function computes the expected number of edges in graphs sampled from the posterior approximation.

   :param posterior_samples: Samples from the posterior. The tensor must have size ``(B, N, N)``, where ``B`` is the number of sample graphs from the posterior approximation and ``N`` is the number of variables in the graphs.

   :returns: ``e_edges``, the expected number of edges.


.. py:function:: expected_shd(posterior_samples, gt_graph)

   Compute the Expected Structural Hamming Distance (SHD).

   This function computes the Expected SHD between a posterior approximation, given as a collection of samples from the posterior, and the ground-truth graph used in the original data generation process.

   :param posterior_samples: Samples from the posterior. The tensor must have size ``(B, N, N)``, where ``B`` is the number of sample graphs from the posterior approximation and ``N`` is the number of variables in the graphs.
   :param gt_graph: GeometricData instance representing the ground-truth graph.

   :returns: ``e_shd``, the Expected SHD.


.. py:function:: posterior_estimate(gflownet, env, num_samples=1000, batch_size=100, verbose=True)

   Get the posterior estimate of DAG-GFlowNet as a collection of graphs sampled from the GFlowNet.

   :param gflownet: ``GFlowNet`` instance.
   :param env: ``BayesianStructure`` environment.
   :param num_samples: The number of samples in the posterior approximation.
   :param batch_size: The number of graphs sampled per batch while drawing the ``num_samples`` samples.
   :param verbose: If True, display a progress bar for the sampling process.

   :returns: ``posterior``, a ``torch.Tensor`` with shape ``(B, N, N)``, where ``B`` is the number of sample graphs in the posterior approximation and ``N`` is the number of variables in a graph.
   :rtype: torch.Tensor


.. py:function:: threshold_metrics(posterior_samples, gt_graph)

   Compute threshold metrics such as AUROC, precision, and recall.

   :param posterior_samples: Samples from the posterior. The tensor must have size ``(B, N, N)``, where ``B`` is the number of sample graphs from the posterior approximation and ``N`` is the number of variables in the graphs.
   :param gt_graph: GeometricData instance representing the ground-truth graph.

   :returns: A dictionary containing the following metrics:

      - False Positive Rate
      - True Positive Rate
      - Area Under the Receiver Operating Characteristic Curve
      - Precision
      - Recall
      - Area Under the Precision-Recall Curve
      - Average Precision

   :rtype: dict
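The sample-based metrics above can be illustrated with a small, self-contained sketch. It is not this module's implementation: it uses NumPy arrays in place of torch tensors, takes the ground truth as a plain ``(N, N)`` adjacency matrix instead of a GeometricData instance, and the ``*_sketch`` function names are hypothetical. The expected SHD here is the plain average over samples, counting a missing, extra, or reversed edge as one error each. The AUROC uses the marginal edge probabilities (the mean over samples) and the rank (Mann-Whitney) formulation rather than a library call, scoring every matrix entry, including the diagonal.

.. code-block:: python

   import numpy as np

   # Sketch only: NumPy stand-ins, not the API of this module.


   def expected_edges_sketch(posterior_samples):
       """Mean edge count over sampled adjacency matrices of shape (B, N, N)."""
       num_edges = posterior_samples.sum(axis=(1, 2))  # edges in each sampled graph
       return float(num_edges.mean())


   def expected_shd_sketch(posterior_samples, gt_adjacency):
       """Mean SHD between each sampled graph and the (N, N) ground-truth adjacency."""
       diff = np.abs(posterior_samples - gt_adjacency[None, :, :])
       diff = diff + diff.transpose(0, 2, 1)  # merge mismatches at (i, j) and (j, i)
       diff = np.minimum(diff, 1)  # a reversed edge counts as one error, not two
       shds = diff.sum(axis=(1, 2)) / 2  # each unordered pair was counted twice
       return float(shds.mean())


   def edge_auroc_sketch(posterior_samples, gt_adjacency):
       """AUROC of the marginal edge probabilities against the ground-truth edges."""
       scores = posterior_samples.mean(axis=0).reshape(-1)  # estimated P(edge i -> j)
       labels = gt_adjacency.reshape(-1)
       pos, neg = scores[labels == 1], scores[labels == 0]
       # AUROC = probability that a true edge outscores a non-edge (ties count 1/2)
       wins = (pos[:, None] > neg[None, :]).sum()
       ties = (pos[:, None] == neg[None, :]).sum()
       return float((wins + 0.5 * ties) / (pos.size * neg.size))


   # Ground truth: the chain 0 -> 1 -> 2.
   gt = np.array([[0, 1, 0],
                  [0, 0, 1],
                  [0, 0, 0]])
   samples = np.stack([
       gt,                                           # exact match: SHD 0
       np.array([[0, 0, 0], [1, 0, 1], [0, 0, 0]]),  # 0 -> 1 reversed: SHD 1
       np.array([[0, 1, 0], [0, 0, 0], [0, 0, 0]]),  # 1 -> 2 missing: SHD 1
   ])
   print(expected_edges_sketch(samples))    # (2 + 2 + 1) / 3
   print(expected_shd_sketch(samples, gt))  # (0 + 1 + 1) / 3
   print(edge_auroc_sketch(samples, gt))

In this toy posterior, each true edge has marginal probability 2/3 and every non-edge has at most 1/3, so the sketch's AUROC is 1.0 even though two of the three samples are wrong. This is why threshold metrics on marginals are usually reported alongside the Expected SHD rather than in place of it.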