gfn.utils.common

Attributes

logger

Classes

Timer

Helper class for timing code execution blocks and accumulating elapsed time in a dictionary.

Functions

default_fill_value_for_dtype(dtype)

Return default fill value for dtype.

ensure_same_device(device1, device2)

Ensure that two devices are the same (for example, the devices of two tensors).

filter_kwargs_for_callable(callable_obj, kwargs)

Filter a kwargs dict to only the parameters accepted by callable_obj.

get_available_cpus()

Return the number of usable CPUs for the current process.

is_int_dtype(tensor)

Check whether a tensor has an integer dtype.

make_dataloader_seed_fns(base_seed[, deterministic_mode])

Return (worker_init_fn, generator) for DataLoader reproducibility.

set_seed(seed[, deterministic_mode])

Set the seeds of all random number generators to control randomness in both single-process and distributed training.

temporarily_set_seed(seed)

Context manager that temporarily sets seeds for multiple RNGs.

timing_print_summary(timing_dict, rank[, train_comm, ...])

Print per-phase mean ± std timing summary.

Module Contents

class gfn.utils.common.Timer(timing_dict, key, enabled=True)

Helper class for timing code execution blocks and accumulating elapsed time in a dictionary.

This class is designed to be used as a context manager to measure the execution time of code blocks. Upon entering the context, it records the start time, and upon exiting, it adds the elapsed time to a specified key in a provided timing dictionary. This is useful for profiling and tracking the time spent in different parts of a program, such as during training loops or data processing steps.

Parameters:
  • timing_dict (dict) – A dictionary where timing results will be accumulated.

  • key (str) – The key in timing_dict under which to accumulate elapsed time.

Example

>>> timing = {}
>>> for name in ["step1", "step2"]:
...     timing[name] = 0
>>> with Timer(timing, "step1"):
...     # Code block to time
...     do_something()
>>> print(f"Elapsed time for step1: {timing['step1']} seconds")

__enter__()
__exit__(exc_type, exc_val, exc_tb)
elapsed = None
enabled = True
key
timing_dict
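
The accumulate-on-exit behaviour described above can be sketched with time.perf_counter. SketchTimer below is a hypothetical stand-in, not part of gfn; it assumes that enabled=False turns the block into a no-op and that exceptions raised inside the block are never suppressed, neither of which is stated above.

```python
import time
from typing import Optional


class SketchTimer:
    """Hypothetical re-implementation of the documented Timer behaviour."""

    def __init__(self, timing_dict: dict, key: str, enabled: bool = True) -> None:
        self.timing_dict = timing_dict
        self.key = key
        self.enabled = enabled  # assumption: False disables timing entirely
        self.elapsed: Optional[float] = None
        self._start = 0.0

    def __enter__(self) -> "SketchTimer":
        if self.enabled:
            self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
        if self.enabled:
            self.elapsed = time.perf_counter() - self._start
            # Accumulate, so repeated blocks under the same key add up.
            self.timing_dict[self.key] = self.timing_dict.get(self.key, 0.0) + self.elapsed
        return False  # never swallow exceptions from the timed block


timing = {"step1": 0.0}
for _ in range(3):
    with SketchTimer(timing, "step1"):
        sum(range(10_000))
print(f"Total time for step1 over 3 runs: {timing['step1']:.6f} s")
```

Accumulating rather than overwriting is what lets the same key be reused across loop iterations, which matches the initialise-to-zero pattern in the example above.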
gfn.utils.common.default_fill_value_for_dtype(dtype)

Return default fill value for dtype.

  • Float and complex dtypes → -inf

  • Integer dtypes → torch.iinfo(dtype).min

  • Bool dtype → 0

Parameters:

dtype (torch.dtype)

Return type:

int | float
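
The dtype-to-fill-value mapping listed above can be sketched as follows. sketch_default_fill_value is a hypothetical name, not the gfn implementation; the masking usage at the end is only one plausible application (padding invalid entries so they never win a max or softmax).

```python
from typing import Union

import torch


def sketch_default_fill_value(dtype: torch.dtype) -> Union[int, float]:
    # Float and complex dtypes -> negative infinity.
    if dtype.is_floating_point or dtype.is_complex:
        return float("-inf")
    # Bool -> 0 (False); torch.iinfo does not accept torch.bool.
    if dtype == torch.bool:
        return 0
    # Remaining integer dtypes -> smallest representable value.
    return torch.iinfo(dtype).min


# Hypothetical usage: mask out invalid logits before a softmax.
logits = torch.tensor([1.0, 2.0, 3.0])
valid = torch.tensor([True, False, True])
masked = logits.masked_fill(~valid, sketch_default_fill_value(logits.dtype))
print(masked)
```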

gfn.utils.common.ensure_same_device(device1, device2)

Ensure that two devices are the same (for example, the devices of two tensors).

Parameters:
  • device1 (torch.device) – The first device.

  • device2 (torch.device) – The second device.

Raises:

ValueError – If the devices are not the same.

Return type:

None
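
A minimal sketch of such a check, assuming plain torch.device equality. sketch_ensure_same_device is hypothetical; note that under plain equality torch.device("cuda") and torch.device("cuda:0") compare unequal, and whether gfn normalizes a missing device index is not stated here.

```python
import torch


def sketch_ensure_same_device(device1: torch.device, device2: torch.device) -> None:
    # Plain equality compares device type and index. A device without an
    # index (e.g. "cuda") is NOT equal to one with index 0 ("cuda:0").
    if device1 != device2:
        raise ValueError(f"Devices differ: {device1} vs {device2}")


x = torch.zeros(3)
y = torch.ones(3)
sketch_ensure_same_device(x.device, y.device)  # both on CPU, so no error
```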

gfn.utils.common.filter_kwargs_for_callable(callable_obj, kwargs)

Filter a kwargs dict to only the parameters accepted by callable_obj.

Parameters:
  • callable_obj (Any)

  • kwargs (dict[str, Any])

Return type:

dict[str, Any]
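
One way to implement this filtering is with inspect.signature. The sketch below is hypothetical (sketch_filter_kwargs, build_model, and config are illustrative names), and it assumes that a callable accepting **kwargs should receive every keyword unchanged and that positional-only parameters are dropped because they cannot be passed by name.

```python
import inspect
from typing import Any, Callable, Dict


def sketch_filter_kwargs(callable_obj: Callable[..., Any], kwargs: Dict[str, Any]) -> Dict[str, Any]:
    params = inspect.signature(callable_obj).parameters
    # Assumption: a callable that takes **kwargs accepts any keyword.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)
    # Only parameters that can be passed by keyword are kept.
    accepted = {
        name
        for name, p in params.items()
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    }
    return {k: v for k, v in kwargs.items() if k in accepted}


def build_model(hidden_dim: int, n_layers: int = 2) -> str:
    return f"{n_layers} layers x {hidden_dim}"


config = {"hidden_dim": 64, "n_layers": 3, "lr": 1e-3}
print(build_model(**sketch_filter_kwargs(build_model, config)))  # prints "3 layers x 64"
```

This lets a single configuration dict feed several constructors without each one failing on keys it does not recognise.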

gfn.utils.common.get_available_cpus()

Return the number of usable CPUs for the current process.

The naive os.cpu_count() often reports the host’s total logical cores, which can be misleading inside containers, job schedulers, or when CPU affinity is restricted. This helper tries to detect the real quota:

  1. On Linux and recent *BSD, it queries os.sched_getaffinity, which already respects cpuset cgroups and taskset affinity masks.

  2. If that is not available, it looks at common thread-limiting environment variables (OMP_NUM_THREADS/MKL_NUM_THREADS/NUMBA_NUM_THREADS).

  3. Finally, it falls back to os.cpu_count() and ensures the return value is at least 1.

Return type:

int
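
The three-step fallback above can be sketched with the standard library alone. sketch_get_available_cpus is hypothetical; the rule that the first positive integer among the environment variables wins is an assumption, since the text does not say how multiple variables are combined.

```python
import os


def sketch_get_available_cpus() -> int:
    # 1. CPU-affinity mask, where the platform provides it.
    if hasattr(os, "sched_getaffinity"):
        try:
            return max(1, len(os.sched_getaffinity(0)))
        except OSError:
            pass
    # 2. Thread-limiting environment variables (assumption: first positive value wins).
    for var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "NUMBA_NUM_THREADS"):
        value = os.environ.get(var, "").strip()
        if value.isdigit() and int(value) > 0:
            return int(value)
    # 3. Host-wide count; os.cpu_count() may return None, so clamp to at least 1.
    return max(1, os.cpu_count() or 1)


print(sketch_get_available_cpus())
```

On Python 3.13 and later, os.process_cpu_count() offers a built-in counterpart to step 1.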

gfn.utils.common.is_int_dtype(tensor)

Check whether a tensor has an integer dtype.

Parameters:

tensor (torch.Tensor)

Return type:

bool
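
A minimal sketch of such a check using torch.dtype properties. sketch_is_int_dtype is hypothetical, and treating torch.bool as non-integer is an assumption; the gfn function may classify bool differently.

```python
import torch


def sketch_is_int_dtype(tensor: torch.Tensor) -> bool:
    dtype = tensor.dtype
    # Integer = neither floating point nor complex; bool excluded (assumption).
    return not dtype.is_floating_point and not dtype.is_complex and dtype != torch.bool


print(sketch_is_int_dtype(torch.tensor([1, 2, 3])))  # prints "True"
print(sketch_is_int_dtype(torch.tensor([1.0])))      # prints "False"
```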

gfn.utils.common.logger
gfn.utils.common.make_dataloader_seed_fns(base_seed, deterministic_mode=False)

Return (worker_init_fn, generator) for DataLoader reproducibility.

Parameters:
  • base_seed (int) – The base seed to use for the DataLoader.

  • deterministic_mode (bool) – If True, uses deterministic behavior for better reproducibility at the cost of performance.

Return type:

Tuple[Callable[[int], None], torch.Generator]

Example

>>> w_init, g = make_dataloader_seed_fns(process_seed)
>>> DataLoader(dataset,
...            num_workers=4,
...            worker_init_fn=w_init,
...            generator=g)

Every worker receives its own deterministic seed, base_seed + worker_id. The returned torch.Generator is seeded with base_seed so that the shuffling order of the dataset is deterministic across runs.
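
The seeding scheme described above can be sketched as follows. sketch_make_dataloader_seed_fns is hypothetical; it seeds only Python's random module and torch in each worker (whether gfn also seeds NumPy is not stated) and omits deterministic_mode.

```python
import random
from typing import Callable, Tuple

import torch


def sketch_make_dataloader_seed_fns(base_seed: int) -> Tuple[Callable[[int], None], torch.Generator]:
    def worker_init_fn(worker_id: int) -> None:
        # Each worker gets its own deterministic seed: base_seed + worker_id.
        worker_seed = base_seed + worker_id
        random.seed(worker_seed)
        torch.manual_seed(worker_seed)

    # A generator seeded with base_seed fixes the sampler's shuffling order.
    generator = torch.Generator()
    generator.manual_seed(base_seed)
    return worker_init_fn, generator


_, g1 = sketch_make_dataloader_seed_fns(7)
_, g2 = sketch_make_dataloader_seed_fns(7)
# Same base seed, same permutation: the shuffling order repeats across runs.
print(torch.randperm(5, generator=g1), torch.randperm(5, generator=g2))
```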

gfn.utils.common.set_seed(seed, deterministic_mode=False)

Set the seeds of all random number generators to control randomness in both single-process and distributed training.

Parameters:
  • seed (int) – The seed to use for all random number generators.

  • deterministic_mode (bool) – If True, uses deterministic behavior for better reproducibility at the cost of performance. In multi-GPU settings, this only affects cuDNN. In multi-CPU settings, it also determines whether NumPy may use parallel processing.

Return type:

None
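
A minimal sketch of a seeding helper in this spirit, covering only Python's random module and torch. sketch_set_seed is hypothetical; it does not seed NumPy, does not handle any per-rank seed offset a distributed setup might use, and maps deterministic_mode onto the cuDNN flags as an assumption.

```python
import random

import torch


def sketch_set_seed(seed: int, deterministic_mode: bool = False) -> None:
    random.seed(seed)
    torch.manual_seed(seed)  # seeds the CPU generator and all CUDA devices
    if deterministic_mode:
        # Assumption: trade cuDNN speed for reproducible kernels.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


sketch_set_seed(0)
first = (random.random(), torch.rand(1).item())
sketch_set_seed(0)
second = (random.random(), torch.rand(1).item())
print(first == second)  # prints "True"
```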

gfn.utils.common.temporarily_set_seed(seed)

Context manager that temporarily sets seeds for multiple RNGs.

Parameters:

seed – The seed value to use within the context

Example

>>> import random
>>> with temporarily_set_seed(42):
...     # Random operations here will use seed 42
...     x = random.random()
>>> # Original random state is restored here

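
The save, seed, restore pattern can be sketched with contextlib.contextmanager. sketch_temporarily_set_seed is hypothetical; it covers only Python's random module and torch's CPU generator (seeded through torch.default_generator so CUDA generators are left untouched), which may be fewer RNGs than gfn handles.

```python
import random
from contextlib import contextmanager
from typing import Iterator

import torch


@contextmanager
def sketch_temporarily_set_seed(seed: int) -> Iterator[None]:
    # Save the current states so they can be restored on exit.
    py_state = random.getstate()
    torch_state = torch.get_rng_state()
    random.seed(seed)
    torch.default_generator.manual_seed(seed)  # CPU generator only
    try:
        yield
    finally:
        # Restore even if the block raised.
        random.setstate(py_state)
        torch.set_rng_state(torch_state)


with sketch_temporarily_set_seed(42):
    x = random.random()
```

The try/finally is what guarantees restoration when the body raises, which a plain seed-then-reset sequence would not.
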
gfn.utils.common.timing_print_summary(timing_dict, rank, train_comm=None, num_training_ranks=1)

Print per-phase mean ± std timing summary.

When train_comm is provided, rank 0 gathers all ranks’ timing data via MPI and prints a consolidated table. Other ranks participate in the gather but do not print. When train_comm is None, only the local rank’s data is printed (single-process fallback).

Parameters:
  • timing_dict (dict[str, list[float]])

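  • rank (int)

  • train_comm – MPI communicator spanning the training ranks, or None for the single-process fallback.

  • num_training_ranks (int)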

Return type:

None
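
The single-process fallback (train_comm is None) can be sketched with the statistics module. sketch_format_timing_summary and sketch_print_local_summary are hypothetical names, the output format is illustrative, and using the population standard deviation is an assumption. With mpi4py, one plausible way to add the multi-rank path would be to collect dictionaries with comm.gather(timing_dict, root=0) and merge them on rank 0; that part is not shown.

```python
import statistics
from typing import Dict, List


def sketch_format_timing_summary(timing_dict: Dict[str, List[float]]) -> List[str]:
    lines = []
    for phase, samples in timing_dict.items():
        if not samples:
            continue  # nothing recorded for this phase
        mean = statistics.fmean(samples)
        std = statistics.pstdev(samples)  # population std (assumption); 0.0 for one sample
        lines.append(f"{phase}: {mean:.4f} ± {std:.4f} s")
    return lines


def sketch_print_local_summary(timing_dict: Dict[str, List[float]], rank: int) -> None:
    # Local-only output, mirroring the train_comm=None case.
    for line in sketch_format_timing_summary(timing_dict):
        print(f"[rank {rank}] {line}")


sketch_print_local_summary({"forward": [0.10, 0.20], "backward": [0.30]}, rank=0)
```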