multinode_scaling_analysis
Multinode Scaling Analysis Script
This script analyzes the results of multinode scaling experiments from the torchgfn project on Weights & Biases. It focuses on n_modes_found progression to understand how many iterations it takes to discover all modes.
Attributes
- STRATEGY_MAPPING
- STRATEGY_PARAMS
Functions
- _add_variance_plot_legend – Add a combined legend for variance plots.
- analyze_communities_within_environments – Analyze and compare communities within each environment.
- analyze_community_variance – Main entry point for community variance analysis.
- analyze_environment_configurations – Analyze environment configurations for Hypergrid.
- analyze_groups – Analyze group structure from run data.
- collect_community_timeseries_data – Collect timeseries data for all agents within each community.
- compare_community_vs_environment_groupings – Compare community groups (wandb) vs environment configurations.
- create_hierarchical_structure – Create hierarchical structure: Environment Groups → Community Groups → Runs
- extract_run_data – Extract basic run information into dataframes.
- extract_strategy_from_runs – Extract strategy configuration from a list of runs.
- fetch_wandb_runs – Fetch all runs from the specified wandb project.
- format_run_ids_display – Format a list of run IDs for display, truncating if longer than max_display.
- format_strategy_for_legend – Format strategy ID for clean legend display.
- get_community_size – Get the size of a community (number of runs/agents).
- identify_runs_for_deletion – Identify runs that should be deleted from wandb.
- main – Main function to run the complete analysis.
- parse_args – Parse command line arguments.
- plot_communities_in_environment – Plot n_modes_found progression for communities within a single environment.
- plot_community_variance_over_time – Plot within-community variance (std dev) over time for a given metric.
- plot_run_durations – Plot run duration analysis for completed runs.
- plot_run_states_distribution – Plot distribution of run states.
- plot_variance_comparison_by_size_and_strategy – Create summary plots comparing variance metrics across sizes and strategies.
- print_available_variables – Print all available variables at run and group level.
- print_environment_configurations – Print detailed information about environment configurations.
- print_experiment_summary – Print overall experiment summary.
Module Contents
- multinode_scaling_analysis.STRATEGY_MAPPING
- multinode_scaling_analysis.STRATEGY_PARAMS = ['average_every', 'replacement_ratio', 'restart_init_mode', 'use_random_strategies',...
- multinode_scaling_analysis._add_variance_plot_legend(fig, legend_entries, size_color_map, strategy_linestyle_map, strategy_marker_map)
Add a combined legend for variance plots.
- multinode_scaling_analysis.analyze_communities_within_environments(env_community_runs, exit_after_printing_strategies=False)
Analyze and compare communities within each environment.
- Parameters:
env_community_runs – Dict mapping environment configs to community runs
exit_after_printing_strategies – If True, exit after printing strategy summary (useful for updating STRATEGY_MAPPING)
- multinode_scaling_analysis.analyze_community_variance(env_community_runs, metrics=None, global_strategy_linestyle_map=None, global_strategy_marker_map=None)
Main entry point for community variance analysis.
Collects timeseries data and generates variance plots for comparing agent performance distribution within communities.
- Parameters:
env_community_runs – Dict mapping env_config_id -> community_id -> list of runs
metrics (list[str] | None) – Metrics to analyze. Default: ["n_modes_found", "n_new_modes", "novelty_sum"]
global_strategy_linestyle_map (dict | None) – Optional pre-computed strategy -> linestyle mapping
global_strategy_marker_map (dict | None) – Optional pre-computed strategy -> marker mapping
- multinode_scaling_analysis.analyze_environment_configurations(runs_list)
Analyze environment configurations for Hypergrid.
An environment configuration is identified by the combination of config.R0, config.R1, config.R2, config.height, and config.ndim.
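The grouping rule above can be sketched as follows. This is a minimal illustration, not the module's implementation: plain config dicts stand in for wandb run.config, the helper name group_by_env_config is hypothetical, and mapping missing keys to None is an assumption.

```python
from collections import defaultdict

# Hypergrid config keys named in the docstring; their order defines the tuple key.
ENV_KEYS = ("R0", "R1", "R2", "height", "ndim")


def group_by_env_config(run_configs):
    """Map each distinct (R0, R1, R2, height, ndim) tuple to the run IDs using it.

    run_configs: dict of run_id -> config dict (hypothetical stand-in for run.config).
    Missing keys become None, so incomplete configs form their own group.
    """
    groups = defaultdict(list)
    for run_id, config in run_configs.items():
        key = tuple(config.get(k) for k in ENV_KEYS)
        groups[key].append(run_id)
    return dict(groups)


runs = {
    "a1": {"R0": 0.1, "R1": 0.5, "R2": 2.0, "height": 8, "ndim": 2},
    "a2": {"R0": 0.1, "R1": 0.5, "R2": 2.0, "height": 8, "ndim": 2},
    "b1": {"R0": 0.001, "R1": 0.5, "R2": 2.0, "height": 16, "ndim": 4},
}
print(group_by_env_config(runs))
```

Using a tuple of the five values as the dictionary key means two runs share a configuration only if all five parameters match exactly.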
- multinode_scaling_analysis.analyze_groups(runs_list)
Analyze group structure from run data.
- multinode_scaling_analysis.collect_community_timeseries_data(env_community_runs, metrics=None, max_steps=1000)
Collect timeseries data for all agents within each community.
This function gathers history data from all runs (agents) within each community, aligned by step index, to enable variance analysis across agents.
- Parameters:
env_community_runs – Dict mapping env_config_id -> community_id -> list of runs
metrics (list[str] | None) – List of metric names to collect. Default: ["n_modes_found", "n_new_modes", "novelty_sum"]
max_steps (int) – Maximum number of steps to collect per run (for memory efficiency)
- Returns:
Dict with the structure:
    env_config_id -> community_id -> {
        'metadata': {'size': int, 'strategy': str},
        'metrics': {
            metric_name: {
                'steps': [0, 1, 2, ...],
                'agent_data': np.array of shape (n_agents, n_steps),
                'mean': np.array of shape (n_steps,),
                'std': np.array of shape (n_steps,),
                'min': np.array of shape (n_steps,),
                'max': np.array of shape (n_steps,),
            }
        }
    }
- Return type:
dict
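The per-metric entry described above can be sketched for a single community as follows. This is a hedged illustration: the helper name community_metric_stats is hypothetical, and truncating every agent to the shortest history (rather than, say, padding with NaN) and using the population standard deviation (ddof=0) are assumptions.

```python
import numpy as np


def community_metric_stats(agent_series, max_steps=1000):
    """Stack per-agent metric histories and summarize them across agents.

    agent_series: list of 1-D sequences, one per agent, indexed by step.
    Histories are truncated to the shortest length (capped at max_steps) so
    every column of agent_data holds one value per agent at the same step.
    """
    if not agent_series:
        raise ValueError("need at least one agent history")
    n_steps = min(min(len(s) for s in agent_series), max_steps)
    agent_data = np.array([np.asarray(s[:n_steps], dtype=float) for s in agent_series])
    return {
        "steps": list(range(n_steps)),
        "agent_data": agent_data,         # shape (n_agents, n_steps)
        "mean": agent_data.mean(axis=0),  # shape (n_steps,)
        "std": agent_data.std(axis=0),    # population std across agents (ddof=0)
        "min": agent_data.min(axis=0),
        "max": agent_data.max(axis=0),
    }


stats = community_metric_stats([[0, 1, 3, 4], [0, 2, 2], [1, 1, 4, 6, 7]])
print(stats["agent_data"].shape, stats["std"])
```

Reducing along axis=0 collapses the agent dimension, which is what makes each 'std' value a within-community spread at one step.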
- multinode_scaling_analysis.compare_community_vs_environment_groupings(run_to_community, run_to_env_config)
Compare community groups (wandb) vs environment configurations.
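One plausible form of this comparison is checking whether any wandb community spans more than one environment configuration. The sketch below assumes that interpretation; the helper name communities_spanning_envs is hypothetical, and the source does not specify what the real function reports.

```python
from collections import defaultdict


def communities_spanning_envs(run_to_community, run_to_env_config):
    """Return community_id -> sorted env config IDs for communities that mix configs.

    A community whose runs all share one environment configuration is
    consistent; any community touching more than one config is returned.
    Runs without an env config are ignored.
    """
    envs_per_community = defaultdict(set)
    for run_id, community_id in run_to_community.items():
        env = run_to_env_config.get(run_id)
        if env is not None:
            envs_per_community[community_id].add(env)
    return {
        cid: sorted(envs)
        for cid, envs in envs_per_community.items()
        if len(envs) > 1
    }


r2c = {"r1": "g1", "r2": "g1", "r3": "g2", "r4": "g2"}
r2e = {"r1": "env0", "r2": "env0", "r3": "env0", "r4": "env1"}
print(communities_spanning_envs(r2c, r2e))
```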
- multinode_scaling_analysis.create_hierarchical_structure(runs_list, run_to_env_config, run_to_community)
Create hierarchical structure: Environment Groups → Community Groups → Runs
- Returns:
env_to_communities – dict mapping environment_config_id to list of community_ids
community_to_runs – dict mapping community_id to list of run_ids
env_community_runs – nested dict[env_config_id][community_id] = list of runs
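The three return values can be built in one pass over the runs. This is a sketch under stated assumptions: dicts with an "id" key stand in for wandb run objects, the helper name build_hierarchy is hypothetical, and skipping runs that lack either label is an assumption.

```python
from collections import defaultdict


def build_hierarchy(runs, run_to_env_config, run_to_community):
    """Return (env_to_communities, community_to_runs, env_community_runs)."""
    env_to_communities = defaultdict(list)
    community_to_runs = defaultdict(list)
    env_community_runs = defaultdict(lambda: defaultdict(list))
    for run in runs:
        run_id = run["id"]
        if run_id not in run_to_env_config or run_id not in run_to_community:
            continue  # runs without both labels cannot be placed in the hierarchy
        env_id = run_to_env_config[run_id]
        community_id = run_to_community[run_id]
        if community_id not in env_to_communities[env_id]:
            env_to_communities[env_id].append(community_id)
        community_to_runs[community_id].append(run_id)
        env_community_runs[env_id][community_id].append(run)
    return (
        dict(env_to_communities),
        dict(community_to_runs),
        {env: dict(comms) for env, comms in env_community_runs.items()},
    )


runs = [{"id": "r1"}, {"id": "r2"}, {"id": "r3"}]
e2c, c2r, ecr = build_hierarchy(
    runs,
    {"r1": "env0", "r2": "env0", "r3": "env1"},
    {"r1": "g1", "r2": "g1", "r3": "g2"},
)
print(e2c, c2r)
```

Converting the defaultdicts back to plain dicts before returning keeps accidental lookups from silently creating empty groups.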
- multinode_scaling_analysis.extract_run_data(runs_list)
Extract basic run information into dataframes.
- multinode_scaling_analysis.extract_strategy_from_runs(runs)
Extract strategy configuration from a list of runs.
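A hedged sketch of reading a community's strategy from its runs' configs. The full STRATEGY_PARAMS list is elided in this page, so the sketch uses only the four names that are visible; the helper name strategy_from_configs and the "read from the first run, then check the others agree" rule are assumptions.

```python
# Only the four names visible in the truncated STRATEGY_PARAMS listing.
VISIBLE_STRATEGY_PARAMS = [
    "average_every",
    "replacement_ratio",
    "restart_init_mode",
    "use_random_strategies",
]


def strategy_from_configs(configs, params=VISIBLE_STRATEGY_PARAMS):
    """Return (strategy dict, consistent flag) for the runs of one community.

    configs: list of config dicts, one per run (stand-in for run.config).
    The strategy is read from the first run; consistent is False if any
    other run disagrees on one of the parameters.
    """
    if not configs:
        return {}, True
    strategy = {p: configs[0].get(p) for p in params}
    consistent = all({p: c.get(p) for p in params} == strategy for c in configs[1:])
    return strategy, consistent


cfg = {"average_every": 10, "replacement_ratio": 0.2,
       "restart_init_mode": "random", "use_random_strategies": False}
print(strategy_from_configs([cfg, dict(cfg)]))
```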
- multinode_scaling_analysis.fetch_wandb_runs(project_name='torchgfn/torchgfn')
Fetch all runs from the specified wandb project.
- multinode_scaling_analysis.format_run_ids_display(run_ids, max_display=3)
Format a list of run IDs for display, truncating if longer than max_display.
- Parameters:
run_ids (list)
max_display (int)
- Return type:
str
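The truncation behaviour can be sketched as below. The helper name truncate_run_ids is hypothetical, and the exact suffix text (", ... (+N more)") is an assumption; the source only says the list is truncated when it is longer than max_display.

```python
def truncate_run_ids(run_ids: list, max_display: int = 3) -> str:
    """Join up to max_display IDs; note how many were hidden, if any."""
    shown = ", ".join(str(r) for r in run_ids[:max_display])
    hidden = len(run_ids) - max_display
    if hidden > 0:
        return f"{shown}, ... (+{hidden} more)"
    return shown


print(truncate_run_ids(["a", "b", "c", "d", "e"]))
```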
- multinode_scaling_analysis.format_strategy_for_legend(strategy_id)
Format strategy ID for clean legend display.
Uses STRATEGY_MAPPING if available, otherwise falls back to auto-formatting.
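The mapping-then-fallback pattern can be sketched as follows. The helper name legend_label, the max_len cap, and the fallback rule (underscores to spaces, then truncation) are assumptions; the real auto-formatting is not described in the source.

```python
def legend_label(strategy_id, mapping=None, max_len=40):
    """Prefer a human-readable label from mapping; otherwise auto-format the ID."""
    if mapping and strategy_id in mapping:
        return mapping[strategy_id]
    label = str(strategy_id).replace("_", " ").strip()
    if len(label) > max_len:
        label = label[: max_len - 3] + "..."  # keep legends from overflowing
    return label


print(legend_label("avg_10", {"avg_10": "Average every 10"}))
print(legend_label("avg_10"))
```

Keeping the curated mapping first means a hand-written label always wins, while unseen strategies still get a readable label instead of an error.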
- multinode_scaling_analysis.get_community_size(community_runs)
Get the size of a community (number of runs/agents).
- multinode_scaling_analysis.identify_runs_for_deletion(runs_list, run_to_community, min_steps=100, include_states=None)
Identify runs that should be deleted from wandb.
This function flags runs that are likely incomplete or failed experiments, based on the _step value in the run summary (which is logged on all ranks).
- Parameters:
runs_list – List of wandb run objects
run_to_community – Dict mapping run_id to community_id
min_steps (int) – Minimum number of steps (_step from the summary) required. Runs with fewer steps are flagged for deletion. This metric is available on ALL ranks, not just rank 0.
include_states (list[str] | None) – List of run states to consider for deletion. Default: ["crashed", "failed"]. Use ["all"] to include all states (including "finished" runs with insufficient data).
- Returns:
A dict with the keys:
runs_to_delete – List of run info dicts
communities_to_delete – List of community IDs where ALL runs should be deleted
summary – Summary statistics
- Return type:
dict
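The selection rule above can be sketched with plain dicts standing in for wandb runs (keys "id", "state", "summary", mirroring the run attributes of the same names). The helper name find_runs_to_delete, the fields in each run info dict, treating a missing _step as 0, and the summary keys are all assumptions.

```python
from collections import defaultdict


def find_runs_to_delete(runs, run_to_community, min_steps=100, include_states=None):
    """Flag short runs in the selected states, plus communities where every run is flagged."""
    if include_states is None:
        include_states = ["crashed", "failed"]
    consider_all = "all" in include_states

    runs_to_delete = []
    for run in runs:
        if not consider_all and run["state"] not in include_states:
            continue
        steps = run["summary"].get("_step") or 0  # missing _step counts as 0 steps
        if steps < min_steps:
            runs_to_delete.append({"id": run["id"], "state": run["state"], "steps": steps,
                                   "community": run_to_community.get(run["id"])})

    flagged = {r["id"] for r in runs_to_delete}
    members = defaultdict(set)
    for run_id, community in run_to_community.items():
        members[community].add(run_id)
    # A community qualifies only if all of its member runs were flagged.
    communities_to_delete = sorted(c for c, ids in members.items() if ids <= flagged)

    return {
        "runs_to_delete": runs_to_delete,
        "communities_to_delete": communities_to_delete,
        "summary": {"n_runs": len(runs), "n_flagged": len(runs_to_delete),
                    "n_communities_flagged": len(communities_to_delete)},
    }


runs = [
    {"id": "r1", "state": "crashed", "summary": {"_step": 10}},
    {"id": "r2", "state": "failed", "summary": {}},
    {"id": "r3", "state": "finished", "summary": {"_step": 5}},
    {"id": "r4", "state": "crashed", "summary": {"_step": 500}},
]
r2c = {"r1": "g1", "r2": "g1", "r3": "g2", "r4": "g2"}
print(find_runs_to_delete(runs, r2c)["communities_to_delete"])
```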
- multinode_scaling_analysis.main()
Main function to run the complete analysis.
- multinode_scaling_analysis.parse_args()
Parse command line arguments.
- multinode_scaling_analysis.plot_communities_in_environment(env_config_id, community_data, community_metadata, strategy_linestyle_map, strategy_marker_map)
Plot n_modes_found progression for communities within a single environment.
Uses color for community size, linestyle and markers for strategy. Linestyle and marker mappings are passed in to ensure consistency across plots.
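Building the style maps once, outside any single plot, is what keeps a strategy's linestyle and marker identical across figures. A minimal sketch of that idea follows; the helper name build_style_maps, the palette, and the style lists are assumptions (the strings are standard matplotlib color names, linestyles, and marker codes).

```python
from itertools import cycle

LINESTYLES = ["-", "--", "-.", ":"]
MARKERS = ["o", "s", "^", "D", "v", "P", "X"]
PALETTE = ("tab:blue", "tab:orange", "tab:green", "tab:red", "tab:purple")


def build_style_maps(strategies, sizes, palette=PALETTE):
    """Return (size -> color, strategy -> linestyle, strategy -> marker).

    Sorting first makes the assignment deterministic, so every plot built
    from the same sets of strategies and sizes gets identical styles.
    cycle() reuses styles if there are more keys than styles.
    """
    strategies = sorted(set(strategies))
    sizes = sorted(set(sizes))
    linestyle_map = dict(zip(strategies, cycle(LINESTYLES)))
    marker_map = dict(zip(strategies, cycle(MARKERS)))
    color_map = dict(zip(sizes, cycle(palette)))
    return color_map, linestyle_map, marker_map


colors, linestyles, markers = build_style_maps(["b", "a", "b", "c"], [8, 2, 4])
print(colors, linestyles, markers)
```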
- multinode_scaling_analysis.plot_community_variance_over_time(variance_data, metric='n_new_modes', global_strategy_linestyle_map=None, global_strategy_marker_map=None)
Plot within-community variance (std dev) over time for a given metric.
Creates plots showing how agent performance variance evolves within each community, allowing comparison across community sizes and strategies.
- Parameters:
variance_data – Output from collect_community_timeseries_data()
metric (str) – Which metric to plot variance for
global_strategy_linestyle_map (dict | None) – Strategy -> linestyle mapping for consistency
global_strategy_marker_map (dict | None) – Strategy -> marker mapping for consistency
- multinode_scaling_analysis.plot_run_durations(df_combined)
Plot run duration analysis for completed runs.
- multinode_scaling_analysis.plot_run_states_distribution(df_runs)
Plot distribution of run states.
- multinode_scaling_analysis.plot_variance_comparison_by_size_and_strategy(variance_data, metric='n_new_modes')
Create summary plots comparing variance metrics across sizes and strategies.
Creates bar charts and box plots showing:
  - Average std dev by community size
  - Average std dev by strategy
  - Final std dev comparison
- Parameters:
variance_data – Output from collect_community_timeseries_data()
metric (str) – Which metric to analyze
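The three summaries can be computed from the documented variance_data structure before any plotting. This sketch assumes "average std dev" means the time-averaged std of each community, then averaged over communities of the same size or strategy; the helper name summarize_std is hypothetical.

```python
from collections import defaultdict

import numpy as np


def summarize_std(variance_data, metric="n_new_modes"):
    """Return (mean std by size, mean std by strategy, final std per community)."""
    by_size = defaultdict(list)
    by_strategy = defaultdict(list)
    final_std = {}
    for env_id, communities in variance_data.items():
        for community_id, entry in communities.items():
            metric_data = entry["metrics"].get(metric)
            if metric_data is None or len(metric_data["std"]) == 0:
                continue  # metric not logged for this community
            std = np.asarray(metric_data["std"], dtype=float)
            meta = entry["metadata"]
            by_size[meta["size"]].append(std.mean())          # time-averaged spread
            by_strategy[meta["strategy"]].append(std.mean())
            final_std[(env_id, community_id)] = float(std[-1])  # spread at last step
    return (
        {k: float(np.mean(v)) for k, v in by_size.items()},
        {k: float(np.mean(v)) for k, v in by_strategy.items()},
        final_std,
    )


def entry(size, strategy, std):
    return {"metadata": {"size": size, "strategy": strategy},
            "metrics": {"n_new_modes": {"std": np.array(std)}}}


data = {"env0": {"g1": entry(2, "A", [1.0, 3.0]),
                 "g2": entry(4, "A", [2.0, 2.0]),
                 "g3": entry(2, "B", [0.0, 2.0])}}
print(summarize_std(data))
```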
- multinode_scaling_analysis.print_available_variables(runs_list, group_runs, check_availability=True)
Print all available variables at run and group level.
- Parameters:
runs_list – List of wandb run objects
group_runs – Dict mapping group_id to list of run_ids
check_availability – If True, check which variables are available across multiple runs (useful for identifying rank-0 only metrics)
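The availability check can be sketched as counting how many runs carry each key: keys logged only on rank 0 show up in a fraction of a community's runs. Plain dicts stand in for run summaries, and the helper names key_availability and partially_available are hypothetical.

```python
from collections import Counter


def key_availability(summaries):
    """Return key -> fraction of runs whose summary contains that key."""
    n = len(summaries)
    if n == 0:
        return {}
    counts = Counter(key for summary in summaries for key in set(summary))
    return {key: counts[key] / n for key in sorted(counts)}


def partially_available(summaries):
    """Keys missing from at least one run (candidates for rank-0-only metrics)."""
    return [k for k, frac in key_availability(summaries).items() if frac < 1.0]


summaries = [{"_step": 1, "n_modes_found": 3}, {"_step": 2}, {"_step": 5}]
print(partially_available(summaries))
```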
- multinode_scaling_analysis.print_environment_configurations(runs_list, env_config_runs, env_config_details, env_to_communities=None, community_to_runs=None)
Print detailed information about environment configurations.
- multinode_scaling_analysis.print_experiment_summary(df_combined)
Print overall experiment summary.