gfn.utils.common
================

.. py:module:: gfn.utils.common


Attributes
----------

.. autoapisummary::

   gfn.utils.common.logger


Classes
-------

.. autoapisummary::

   gfn.utils.common.Timer


Functions
---------

.. autoapisummary::

   gfn.utils.common.default_fill_value_for_dtype
   gfn.utils.common.ensure_same_device
   gfn.utils.common.filter_kwargs_for_callable
   gfn.utils.common.get_available_cpus
   gfn.utils.common.is_int_dtype
   gfn.utils.common.make_dataloader_seed_fns
   gfn.utils.common.set_seed
   gfn.utils.common.temporarily_set_seed
   gfn.utils.common.timing_print_summary


Module Contents
---------------

.. py:class:: Timer(timing_dict, key, enabled=True)

   Helper class for timing blocks of code and accumulating the elapsed time in a dictionary.

   This class is designed to be used as a context manager. On entering the context it records
   the start time; on exiting it adds the elapsed time to the given key in the provided timing
   dictionary. This is useful for profiling the time spent in different parts of a program,
   such as training loops or data processing steps.

   :param timing_dict: A dictionary in which timing results are accumulated.
   :type timing_dict: dict
   :param key: The key in ``timing_dict`` under which to accumulate elapsed time.
   :type key: str

   .. rubric:: Example

   .. code-block:: python

      timing = {}
      for name in ["step1", "step2"]:
          timing[name] = 0
      with Timer(timing, "step1"):
          # Code block to time
          do_something()
      print(f"Elapsed time for step1: {timing['step1']} seconds")


   .. py:method:: __enter__()


   .. py:method:: __exit__(exc_type, exc_val, exc_tb)


   .. py:attribute:: elapsed
      :value: None


   .. py:attribute:: enabled
      :value: True


   .. py:attribute:: key


   .. py:attribute:: timing_dict


.. py:function:: default_fill_value_for_dtype(dtype)

   Return the default fill value for ``dtype``.

   - Float and complex dtypes → ``-inf``
   - Integer dtypes → ``torch.iinfo(dtype).min``
   - Bool dtype → ``0``


.. py:function:: ensure_same_device(device1, device2)

   Ensure that two devices are the same.

   :param device1: The first device.
   :param device2: The second device.

   :raises ValueError: If the devices are not the same.


.. py:function:: filter_kwargs_for_callable(callable_obj, kwargs)

   Filter a kwargs dict down to the parameters accepted by ``callable_obj``.


.. py:function:: get_available_cpus()

   Return the number of *usable* CPUs for the current process.

   The naive ``os.cpu_count()`` often reports the host's total logical cores, which can be
   misleading inside containers, under job schedulers, or when CPU affinity is restricted.
   This helper tries to detect the real quota:

   1. On Linux and recent \*BSD it queries ``os.sched_getaffinity``, which already
      respects cgroups and task-set masks.
   2. If that is not available, it looks at common thread-limiting environment variables
      (``OMP_NUM_THREADS``/``MKL_NUM_THREADS``/``NUMBA_NUM_THREADS``).
   3. Finally, it falls back to ``os.cpu_count()`` and ensures the return value is at
      least ``1``.


.. py:function:: is_int_dtype(tensor)

   Check whether a tensor has an integer dtype.


.. py:data:: logger


.. py:function:: make_dataloader_seed_fns(base_seed, deterministic_mode = False)

   Return ``(worker_init_fn, generator)`` for DataLoader reproducibility.

   :param base_seed: The base seed to use for the DataLoader.
   :param deterministic_mode: If True, uses deterministic behavior for better reproducibility
                              at the cost of performance.

   .. rubric:: Example

   >>> w_init, g = make_dataloader_seed_fns(process_seed)
   >>> DataLoader(dataset,
   ...            num_workers=4,
   ...            worker_init_fn=w_init,
   ...            generator=g)

   Every worker receives its own deterministic seed ``base_seed + worker_id``.
   The returned ``torch.Generator`` is seeded with ``base_seed`` so that the shuffled
   order of the dataset is deterministic across runs.


.. py:function:: set_seed(seed, deterministic_mode = False)

   Control randomness for both single-process and distributed training.

   :param seed: The seed to use for all random number generators.
   :param deterministic_mode: If True, uses deterministic behavior for better reproducibility,
                              possibly at the cost of performance. In multi-GPU settings, this
                              only affects cuDNN. In multi-CPU settings, this allows parallel
                              processing in NumPy.


.. py:function:: temporarily_set_seed(seed)

   Context manager that temporarily sets seeds for multiple RNGs.

   :param seed: The seed value to use within the context.

   .. rubric:: Example

   >>> with temporarily_set_seed(42):
   ...     # Random operations here will use seed 42
   ...     x = random.random()
   >>> # Original random state is restored here


.. py:function:: timing_print_summary(timing_dict, rank, train_comm=None, num_training_ranks = 1)

   Print a per-phase mean ± std timing summary.

   When *train_comm* is provided, rank 0 gathers all ranks' timing data via MPI
   and prints a consolidated table. Other ranks participate in the gather but
   do not print.

   When *train_comm* is ``None``, only the local rank's data is printed
   (single-process fallback).
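
The per-phase mean ± std aggregation described for ``timing_print_summary`` can be illustrated without MPI. The following is only a rough sketch under stated assumptions, not the library's implementation: it assumes each rank contributes a dict mapping a phase name to elapsed seconds and that the statistics are taken across ranks. The helper ``summarize_timings`` and the ``gathered`` list are hypothetical names introduced here for illustration.

```python
import statistics


def summarize_timings(per_rank_timings):
    """Hypothetical helper: map each phase name to (mean, std) across ranks."""
    # Collect every phase name that appears on at least one rank.
    phases = sorted({phase for timing in per_rank_timings for phase in timing})
    summary = {}
    for phase in phases:
        # Ranks that never recorded this phase are skipped.
        values = [timing[phase] for timing in per_rank_timings if phase in timing]
        # Population standard deviation, so a single rank yields 0.0.
        summary[phase] = (statistics.mean(values), statistics.pstdev(values))
    return summary


# Assumed stand-in for what a gather on rank 0 might deliver: one dict per rank.
gathered = [
    {"sample": 1.0, "train": 4.0},
    {"sample": 3.0, "train": 4.0},
]

summary = summarize_timings(gathered)
for phase, (mean, std) in summary.items():
    print(f"{phase:<10s} {mean:.3f} +/- {std:.3f} s")
```

With a single entry in ``gathered``, the same function gives a standard deviation of zero for every phase, which mirrors the single-process fallback where only local data is available.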