multinode_scaling_analysis

Multinode Scaling Analysis Script

This script analyzes the results of multinode scaling experiments from the torchgfn project on Weights & Biases. It focuses on the progression of n_modes_found to determine how many iterations are needed to discover all modes.

Attributes

STRATEGY_MAPPING

STRATEGY_PARAMS

Functions

_add_variance_plot_legend(fig, legend_entries, ...)

Add a combined legend for variance plots.

analyze_communities_within_environments(env_community_runs)

Analyze and compare communities within each environment.

analyze_community_variance(env_community_runs[, ...])

Main entry point for community variance analysis.

analyze_environment_configurations(runs_list)

Analyze environment configurations for Hypergrid.

analyze_groups(runs_list)

Analyze group structure from run data.

collect_community_timeseries_data(env_community_runs)

Collect timeseries data for all agents within each community.

compare_community_vs_environment_groupings(...)

Compare community groups (wandb) vs environment configurations.

create_hierarchical_structure(runs_list, ...)

Create hierarchical structure: Environment Groups → Community Groups → Runs

extract_run_data(runs_list)

Extract basic run information into dataframes.

extract_strategy_from_runs(runs)

Extract strategy configuration from a list of runs.

fetch_wandb_runs([project_name])

Fetch all runs from the specified wandb project.

format_run_ids_display(run_ids[, max_display])

Format a list of run IDs for display, truncating if longer than max_display.

format_strategy_for_legend(strategy_id)

Format strategy ID for clean legend display.

get_community_size(community_runs)

Get the size of a community (number of runs/agents).

identify_runs_for_deletion(runs_list, run_to_community)

Identify runs that should be deleted from wandb.

main()

Main function to run the complete analysis.

parse_args()

Parse command line arguments.

plot_communities_in_environment(env_config_id, ...)

Plot n_modes_found progression for communities within a single environment.

plot_community_variance_over_time(variance_data[, ...])

Plot within-community variance (std dev) over time for a given metric.

plot_run_durations(df_combined)

Plot run duration analysis for completed runs.

plot_run_states_distribution(df_runs)

Plot distribution of run states.

plot_variance_comparison_by_size_and_strategy(...[, ...])

Create summary plots comparing variance metrics across sizes and strategies.

print_available_variables(runs_list, group_runs[, ...])

Print all available variables at run and group level.

print_environment_configurations(runs_list, ...[, ...])

Print detailed information about environment configurations.

print_experiment_summary(df_combined)

Print overall experiment summary.

Module Contents

multinode_scaling_analysis.STRATEGY_MAPPING
multinode_scaling_analysis.STRATEGY_PARAMS = ['average_every', 'replacement_ratio', 'restart_init_mode', 'use_random_strategies',...
multinode_scaling_analysis._add_variance_plot_legend(fig, legend_entries, size_color_map, strategy_linestyle_map, strategy_marker_map)

Add a combined legend for variance plots.

multinode_scaling_analysis.analyze_communities_within_environments(env_community_runs, exit_after_printing_strategies=False)

Analyze and compare communities within each environment.

Parameters:
  • env_community_runs – Dict mapping environment configs to community runs

  • exit_after_printing_strategies – If True, exit after printing strategy summary (useful for updating STRATEGY_MAPPING)

multinode_scaling_analysis.analyze_community_variance(env_community_runs, metrics=None, global_strategy_linestyle_map=None, global_strategy_marker_map=None)

Main entry point for community variance analysis.

Collects timeseries data and generates variance plots for comparing agent performance distribution within communities.

Parameters:
  • env_community_runs – Dict mapping env_config_id -> community_id -> list of runs

  • metrics (list[str] | None) – Metrics to analyze. Default: ["n_modes_found", "n_new_modes", "novelty_sum"]

  • global_strategy_linestyle_map (dict | None) – Optional pre-computed strategy -> linestyle mapping

  • global_strategy_marker_map (dict | None) – Optional pre-computed strategy -> marker mapping

multinode_scaling_analysis.analyze_environment_configurations(runs_list)

Analyze environment configurations for Hypergrid.

Config is based on config.R0, config.R1, config.R2, config.height, config.ndim.
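As a minimal sketch of how runs could be grouped by these five fields, the example below builds a hashable key from each run's config. Plain dicts stand in for wandb run configs, and the names ENV_CONFIG_KEYS, env_config_key and group_runs_by_env_config are hypothetical helpers, not part of the module.

```python
from collections import defaultdict

# Fields named in the docstring; the tuple order is an assumption.
ENV_CONFIG_KEYS = ("R0", "R1", "R2", "height", "ndim")


def env_config_key(config):
    """Build a hashable key from the fields that define a Hypergrid environment."""
    return tuple(config.get(key) for key in ENV_CONFIG_KEYS)


def group_runs_by_env_config(run_configs):
    """Map each distinct environment key to the run IDs that share it."""
    groups = defaultdict(list)
    for run_id, config in run_configs.items():
        groups[env_config_key(config)].append(run_id)
    return dict(groups)


# Illustrative configs only; the values are made up.
run_configs = {
    "run_a": {"R0": 0.1, "R1": 0.5, "R2": 2.0, "height": 8, "ndim": 2, "seed": 0},
    "run_b": {"R0": 0.1, "R1": 0.5, "R2": 2.0, "height": 8, "ndim": 2, "seed": 1},
    "run_c": {"R0": 0.1, "R1": 0.5, "R2": 2.0, "height": 16, "ndim": 4, "seed": 0},
}
groups = group_runs_by_env_config(run_configs)
```

Fields outside the key (here, seed) do not affect the grouping, so run_a and run_b land in the same environment group.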

multinode_scaling_analysis.analyze_groups(runs_list)

Analyze group structure from run data.

multinode_scaling_analysis.collect_community_timeseries_data(env_community_runs, metrics=None, max_steps=1000)

Collect timeseries data for all agents within each community.

This function gathers history data from all runs (agents) within each community, aligned by step index, to enable variance analysis across agents.

Parameters:
  • env_community_runs – Dict mapping env_config_id -> community_id -> list of runs

  • metrics (list[str] | None) – List of metric names to collect. Default: ["n_modes_found", "n_new_modes", "novelty_sum"]

  • max_steps (int) – Maximum number of steps to collect per run (for memory efficiency)

Returns:

Dict with structure env_config_id -> community_id -> entry, where each entry has the form:

    {
        'metadata': {'size': int, 'strategy': str},
        'metrics': {
            metric_name: {
                'steps': [0, 1, 2, ...],
                'agent_data': np.array of shape (n_agents, n_steps),
                'mean': np.array of shape (n_steps,),
                'std': np.array of shape (n_steps,),
                'min': np.array of shape (n_steps,),
                'max': np.array of shape (n_steps,),
            }
        }
    }

Return type:

dict
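The per-metric statistics in this structure can be illustrated with a short NumPy sketch. It assumes that runs of unequal length are truncated to the shortest one and that the standard deviation is the population form (NumPy's default, ddof=0); the module may align or pad runs differently. The helper name summarize_agents is hypothetical.

```python
import numpy as np


def summarize_agents(agent_series, max_steps=1000):
    """Stack per-agent series by step index and compute per-step statistics."""
    # Assumption: align by truncating every series to the shortest one,
    # capped at max_steps.
    n_steps = min(min(len(series) for series in agent_series), max_steps)
    agent_data = np.array([series[:n_steps] for series in agent_series], dtype=float)
    return {
        "steps": list(range(n_steps)),
        "agent_data": agent_data,           # shape (n_agents, n_steps)
        "mean": agent_data.mean(axis=0),    # shape (n_steps,)
        "std": agent_data.std(axis=0),      # population std across agents
        "min": agent_data.min(axis=0),
        "max": agent_data.max(axis=0),
    }


# Three agents with made-up n_modes_found histories of different lengths.
entry = summarize_agents([[0, 1, 3, 4], [0, 2, 3], [0, 3, 3, 5]])
```

Reducing along axis 0 collapses the agent dimension, so each statistic describes the spread across agents at a given step.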

multinode_scaling_analysis.compare_community_vs_environment_groupings(run_to_community, run_to_env_config)

Compare community groups (wandb) vs environment configurations.

multinode_scaling_analysis.create_hierarchical_structure(runs_list, run_to_env_config, run_to_community)

Create hierarchical structure: Environment Groups → Community Groups → Runs

Returns:
  • env_to_communities – dict mapping environment_config_id to list of community_ids

  • community_to_runs – dict mapping community_id to list of run_ids

  • env_community_runs – nested dict[env_config_id][community_id] = list of runs
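A minimal sketch of how the three returned mappings relate to each other is shown below. It works on run IDs rather than wandb run objects, and it assumes that runs missing from either input mapping are skipped; build_hierarchy is a hypothetical name, not the module's function.

```python
from collections import defaultdict


def build_hierarchy(run_ids, run_to_env_config, run_to_community):
    """Nest runs as environment -> community -> runs and derive the flat views."""
    env_community_runs = defaultdict(lambda: defaultdict(list))
    community_to_runs = defaultdict(list)
    for run_id in run_ids:
        env_id = run_to_env_config.get(run_id)
        community_id = run_to_community.get(run_id)
        if env_id is None or community_id is None:
            continue  # assumption: ignore runs without both assignments
        env_community_runs[env_id][community_id].append(run_id)
        community_to_runs[community_id].append(run_id)
    # Convert to plain dicts so lookups of unknown keys raise KeyError.
    nested = {env: dict(comms) for env, comms in env_community_runs.items()}
    env_to_communities = {env: list(comms) for env, comms in nested.items()}
    return env_to_communities, dict(community_to_runs), nested


env_to_communities, community_to_runs, env_community_runs = build_hierarchy(
    ["r1", "r2", "r3", "r4"],
    {"r1": "env0", "r2": "env0", "r3": "env0", "r4": "env1"},
    {"r1": "c1", "r2": "c1", "r3": "c2", "r4": "c3"},
)
```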

multinode_scaling_analysis.extract_run_data(runs_list)

Extract basic run information into dataframes.

multinode_scaling_analysis.extract_strategy_from_runs(runs)

Extract strategy configuration from a list of runs.

multinode_scaling_analysis.fetch_wandb_runs(project_name='torchgfn/torchgfn')

Fetch all runs from the specified wandb project.

multinode_scaling_analysis.format_run_ids_display(run_ids, max_display=3)

Format a list of run IDs for display, truncating if longer than max_display.

Parameters:
  • run_ids (list)

  • max_display (int)

Return type:

str
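The truncation behaviour can be sketched as follows. The exact output format of the real function is not documented here, so the separator and the "(+N more)" suffix are assumptions, and format_ids is a hypothetical stand-in.

```python
def format_ids(run_ids, max_display=3):
    """Join up to max_display IDs and note how many were left out."""
    shown = ", ".join(str(run_id) for run_id in run_ids[:max_display])
    hidden = len(run_ids) - max_display
    if hidden > 0:
        # Assumed suffix format; the module may render this differently.
        return f"{shown}, ... (+{hidden} more)"
    return shown


short_text = format_ids(["a", "b"])
long_text = format_ids(["a", "b", "c", "d", "e"])
```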

multinode_scaling_analysis.format_strategy_for_legend(strategy_id)

Format strategy ID for clean legend display.

Uses STRATEGY_MAPPING if available, otherwise falls back to auto-formatting.
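The lookup-then-fallback pattern can be sketched as below. The contents and shape of STRATEGY_MAPPING are not shown in this reference, so the example uses a hypothetical EXAMPLE_MAPPING from strategy ID to label, and the auto-formatting rule (underscores to spaces, title case) is likewise an assumption.

```python
# Hypothetical mapping; the real STRATEGY_MAPPING is not reproduced here.
EXAMPLE_MAPPING = {"avg_10": "Averaging (every 10)"}


def format_for_legend(strategy_id, mapping=EXAMPLE_MAPPING):
    """Prefer a curated label, otherwise derive a readable one from the ID."""
    if strategy_id in mapping:
        return mapping[strategy_id]
    # Assumed fallback: replace underscores with spaces and title-case the result.
    return strategy_id.replace("_", " ").title()


known_label = format_for_legend("avg_10")
fallback_label = format_for_legend("restart_init_mode")
```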

multinode_scaling_analysis.get_community_size(community_runs)

Get the size of a community (number of runs/agents).

multinode_scaling_analysis.identify_runs_for_deletion(runs_list, run_to_community, min_steps=100, include_states=None)

Identify runs that should be deleted from wandb.

This function identifies runs that are likely incomplete or failed experiments based on the _step metric from run summary (universally available on all ranks).

Parameters:
  • runs_list – List of wandb run objects

  • run_to_community – Dict mapping run_id to community_id

  • min_steps (int) – Minimum number of steps (_step from summary) required. Runs with fewer steps are flagged for deletion. This metric is available on ALL ranks, not just rank-0.

  • include_states (list[str] | None) – List of run states to consider for deletion. Default: ["crashed", "failed"]. Use ["all"] to include all states (including "finished" runs with insufficient data).

Returns:

dict with the keys:

  • runs_to_delete: List of run info dicts

  • communities_to_delete: List of community IDs where ALL runs should be deleted

  • summary: Summary statistics

Return type:

dict
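The selection rule described above can be sketched with plain dicts standing in for wandb run objects. The sketch assumes a run is flagged when its state is among include_states and its step count is below min_steps, and that a community is listed only when every one of its runs is flagged. The function name flag_runs, the run dict keys and the summary fields are hypothetical.

```python
def flag_runs(runs, run_to_community, min_steps=100, include_states=None):
    """Flag short runs in the selected states and fully flagged communities."""
    if include_states is None:
        include_states = ["crashed", "failed"]
    consider_all = "all" in include_states

    runs_to_delete = []
    for run in runs:
        if not consider_all and run["state"] not in include_states:
            continue
        if run["step"] < min_steps:
            runs_to_delete.append(run)

    # A community qualifies only if all of its member runs were flagged.
    flagged_ids = {run["id"] for run in runs_to_delete}
    members = {}
    for run in runs:
        community_id = run_to_community.get(run["id"])
        if community_id is not None:
            members.setdefault(community_id, []).append(run["id"])
    communities_to_delete = [
        community_id
        for community_id, ids in members.items()
        if all(run_id in flagged_ids for run_id in ids)
    ]

    return {
        "runs_to_delete": runs_to_delete,
        "communities_to_delete": communities_to_delete,
        "summary": {"n_runs": len(runs), "n_flagged": len(runs_to_delete)},
    }


# Made-up runs: 'step' plays the role of the _step summary value.
runs = [
    {"id": "r1", "state": "crashed", "step": 10},
    {"id": "r2", "state": "crashed", "step": 500},
    {"id": "r3", "state": "finished", "step": 5},
    {"id": "r4", "state": "failed", "step": 0},
]
run_to_community = {"r1": "c1", "r2": "c1", "r3": "c2", "r4": "c3"}
default_result = flag_runs(runs, run_to_community)
all_states_result = flag_runs(runs, run_to_community, include_states=["all"])
```

With the default states, the short "finished" run r3 is left alone; passing ["all"] flags it too, which also makes its community c2 eligible.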

multinode_scaling_analysis.main()

Main function to run the complete analysis.

multinode_scaling_analysis.parse_args()

Parse command line arguments.

multinode_scaling_analysis.plot_communities_in_environment(env_config_id, community_data, community_metadata, strategy_linestyle_map, strategy_marker_map)

Plot n_modes_found progression for communities within a single environment.

Uses color for community size, linestyle and markers for strategy. Linestyle and marker mappings are passed in to ensure consistency across plots.
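One way to obtain linestyle and marker mappings that stay consistent across plots is to assign them once over the sorted set of strategies and pass the result to every plotting call. The sketch below assumes that approach; build_style_maps and the particular style lists are hypothetical, although the style strings themselves are standard matplotlib values.

```python
from itertools import cycle

# Standard matplotlib linestyle and marker codes; the selection is arbitrary.
LINESTYLES = ["-", "--", "-.", ":"]
MARKERS = ["o", "s", "^", "D", "v"]


def build_style_maps(strategies):
    """Assign each distinct strategy a linestyle and marker deterministically."""
    ordered = sorted(set(strategies))
    # cycle() repeats the style lists if there are more strategies than styles.
    linestyle_map = dict(zip(ordered, cycle(LINESTYLES)))
    marker_map = dict(zip(ordered, cycle(MARKERS)))
    return linestyle_map, marker_map


linestyle_map, marker_map = build_style_maps(["b", "a", "b"])
```

Sorting before assignment means the same strategy receives the same style regardless of the order in which communities are encountered.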

multinode_scaling_analysis.plot_community_variance_over_time(variance_data, metric='n_new_modes', global_strategy_linestyle_map=None, global_strategy_marker_map=None)

Plot within-community variance (std dev) over time for a given metric.

Creates plots showing how agent performance variance evolves within each community, allowing comparison across community sizes and strategies.

Parameters:
  • variance_data – Output from collect_community_timeseries_data()

  • metric (str) – Which metric to plot variance for

  • global_strategy_linestyle_map (dict | None) – Strategy -> linestyle mapping for consistency

  • global_strategy_marker_map (dict | None) – Strategy -> marker mapping for consistency

multinode_scaling_analysis.plot_run_durations(df_combined)

Plot run duration analysis for completed runs.

multinode_scaling_analysis.plot_run_states_distribution(df_runs)

Plot distribution of run states.

multinode_scaling_analysis.plot_variance_comparison_by_size_and_strategy(variance_data, metric='n_new_modes')

Create summary plots comparing variance metrics across sizes and strategies.

Creates bar charts and box plots showing:
  • Average std dev by community size

  • Average std dev by strategy

  • Final std dev comparison

Parameters:
  • variance_data – Output from collect_community_timeseries_data()

  • metric (str) – Which metric to analyze
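The "average std dev by community size" aggregation can be sketched from the documented structure of variance_data. This is an illustration under assumptions: each community's std series is first averaged over time, and those per-community values are then averaged per size. The helper mean_std_by_size is hypothetical and produces numbers only, not plots.

```python
from collections import defaultdict

import numpy as np


def mean_std_by_size(variance_data, metric="n_new_modes"):
    """Average the time-averaged within-community std for each community size."""
    per_size = defaultdict(list)
    for communities in variance_data.values():
        for entry in communities.values():
            stats = entry["metrics"].get(metric)
            if stats is None:
                continue  # this community has no data for the metric
            per_size[entry["metadata"]["size"]].append(float(np.mean(stats["std"])))
    return {size: float(np.mean(values)) for size, values in sorted(per_size.items())}


# Made-up data following the env_config_id -> community_id -> entry layout.
variance_data = {
    "env0": {
        "c1": {"metadata": {"size": 2, "strategy": "s"},
               "metrics": {"n_new_modes": {"std": np.array([0.0, 2.0])}}},
        "c2": {"metadata": {"size": 2, "strategy": "t"},
               "metrics": {"n_new_modes": {"std": np.array([1.0, 3.0])}}},
        "c3": {"metadata": {"size": 4, "strategy": "s"},
               "metrics": {"n_new_modes": {"std": np.array([4.0, 4.0])}}},
    }
}
by_size = mean_std_by_size(variance_data)
```

Grouping on entry["metadata"]["strategy"] instead of size would give the corresponding per-strategy summary.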

multinode_scaling_analysis.print_available_variables(runs_list, group_runs, check_availability=True)

Print all available variables at run and group level.

Parameters:
  • runs_list – List of wandb run objects

  • group_runs – Dict mapping group_id to list of run_ids

  • check_availability – If True, check which variables are available across multiple runs (useful for identifying rank-0 only metrics)

multinode_scaling_analysis.print_environment_configurations(runs_list, env_config_runs, env_config_details, env_to_communities=None, community_to_runs=None)

Print detailed information about environment configurations.

multinode_scaling_analysis.print_experiment_summary(df_combined)

Print overall experiment summary.