gfn.gym.perfect_tree

Classes

PerfectBinaryTree

Perfect Tree Environment.

Module Contents

class gfn.gym.perfect_tree.PerfectBinaryTree(reward_fn, depth=4, device=None, debug=False)

Bases: gfn.env.DiscreteEnv

Perfect Tree Environment.

This environment is a perfect binary tree, where there is a bijection between trajectories and terminating states. Nodes are represented by integers, starting from 0 for the root. States are represented by a single integer tensor corresponding to the node index. Actions are integers: 0 (left child), 1 (right child), 2 (exit).

For example, with depth=3:

                  0 (root)
              /       \
          1               2
        /   \           /   \
      3       4       5       6
     / \     / \     / \     / \
    7   8   9  10  11  12  13  14

Nodes 7 to 14 are the terminating states when depth=3.
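The numbering in the diagram matches the layout of an array-backed binary heap: node n has left child 2n + 1 and right child 2n + 2. The sketch below is a minimal illustration that assumes the environment numbers nodes exactly as drawn. The helper name `child` and the constants are hypothetical and are not part of gfn.

```python
LEFT, RIGHT, EXIT = 0, 1, 2  # action encoding from the class description


def child(node: int, action: int) -> int:
    """Index of the child reached from `node` by LEFT (0) or RIGHT (1).

    Assumes heap-style numbering: children of n are 2n + 1 and 2n + 2.
    """
    if action not in (LEFT, RIGHT):
        raise ValueError("only LEFT or RIGHT lead to a child")
    return 2 * node + 1 + action


# Follow the right-most path of the depth=3 tree: 0 -> 2 -> 6 -> 14.
path = [0]
for _ in range(3):
    path.append(child(path[-1], RIGHT))
print(path)  # [0, 2, 6, 14]
```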

Recommended preprocessor: OneHotPreprocessor.
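A one-hot preprocessor gives the policy network a vector with a single 1 marking the current state. The dependency-free sketch below assumes one slot per node, with the slot index equal to the node index. The function `one_hot` is hypothetical and does not reproduce the interface of gfn's OneHotPreprocessor.

```python
def one_hot(node: int, n_nodes: int) -> list[int]:
    """Encode a node index as a one-hot list of length n_nodes.

    Assumes slot i corresponds to node i (hypothetical layout).
    """
    if not 0 <= node < n_nodes:
        raise ValueError("node index out of range")
    vec = [0] * n_nodes
    vec[node] = 1
    return vec


# The depth=3 tree has 15 nodes; encode node 4.
print(one_hot(4, 15))
```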

Parameters:
  • reward_fn (Callable)

  • depth (int)

  • device (Literal['cpu', 'cuda'] | torch.device | None)

  • debug (bool)

reward_fn

A function that computes the reward for a given state.

Type:

Callable

depth

The depth of the tree.

Type:

int

branching_factor

The branching factor of the tree.

Type:

int

n_actions

The number of actions.

Type:

int

n_nodes

The number of nodes in the tree.

Type:

int

transition_table

A dictionary that maps (state, action) to the next state.

Type:

dict

inverse_transition_table

A dictionary that maps (state, action) to the previous state.

Type:

dict

term_states

The terminating states.

Type:

DiscreteStates

States: type[gfn.env.DiscreteStates]
_build_tree()

Builds the tree and the transition tables.

Returns:

A tuple containing the transition table, the inverse transition table, and the terminating states.

Return type:

tuple[dict, dict, gfn.env.DiscreteStates]
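Here is a rough idea of what building these tables could look like for heap-style numbering. This is a sketch under stated assumptions, not the library's code. It assumes `transition_table` is keyed by (parent, action) and `inverse_transition_table` by (child, action that led to it). It also returns terminating states as a plain list of ints instead of a DiscreteStates object.

```python
def build_tree(depth: int) -> tuple[dict, dict, list[int]]:
    """Return (transition_table, inverse_transition_table, terminating_states).

    Hypothetical sketch: keys are (node, action) tuples with action 0 = left,
    1 = right; the exit action (2) is not stored in either table.
    """
    n_nodes = 2 ** (depth + 1) - 1
    first_leaf = 2 ** depth - 1  # leaves occupy the last 2**depth indices
    transition_table = {}
    inverse_transition_table = {}
    for node in range(first_leaf):  # internal nodes only
        for action in (0, 1):
            nxt = 2 * node + 1 + action
            transition_table[(node, action)] = nxt
            inverse_transition_table[(nxt, action)] = node
    terminating_states = list(range(first_leaf, n_nodes))
    return transition_table, inverse_transition_table, terminating_states


fwd, inv, term = build_tree(3)
print(fwd[(0, 1)], inv[(14, 1)], term)  # 2 6 [7, 8, 9, 10, 11, 12, 13, 14]
```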

property all_states: gfn.env.DiscreteStates

Returns all the states of the environment.

Return type:

gfn.env.DiscreteStates

backward_step(states, actions)

Performs a backward step in the environment.

Parameters:
  • states (gfn.env.DiscreteStates) – The current states.

  • actions (gfn.env.Actions) – The actions to take.

Returns:

The previous states.

Return type:

gfn.env.DiscreteStates
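Each non-root node in a tree has exactly one parent, so a backward step only has to undo the edge that led into the node. The batched sketch below operates on plain lists of node indices and assumes heap-style numbering. It also assumes the backward action names the edge being undone (0 if the node is a left child, 1 if it is a right child). That convention is hypothetical.

```python
def backward_step(states: list[int], actions: list[int]) -> list[int]:
    """Return the parent of each node, checking each action against the edge undone."""
    if len(states) != len(actions):
        raise ValueError("states and actions must have the same length")
    parents = []
    for node, action in zip(states, actions):
        if node <= 0:
            raise ValueError("the root has no parent")
        # Left children (2n + 1) satisfy (node - 1) % 2 == 0; right children give 1.
        if (node - 1) % 2 != action:
            raise ValueError(f"action {action} did not lead to node {node}")
        parents.append((node - 1) // 2)
    return parents


print(backward_step([14, 7, 4], [1, 0, 1]))  # [6, 3, 1]
```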

branching_factor = 2
depth = 4
get_states_indices(states)

Returns the indices of the states.

Parameters:

states (gfn.states.States) – The states to get the indices of.

Returns:

The indices of the states.

make_states_class()

Returns the DiscreteStates class for the PerfectBinaryTree environment.

Return type:

type[gfn.env.DiscreteStates]

n_actions = 3
n_nodes = 31
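The defaults above are consistent with each other. A perfect binary tree of depth d has 2^k nodes at level k, which gives 2^(d+1) - 1 nodes in total, 2^d of them leaves. The check below assumes depth counts edges from the root to a leaf, as in the depth=3 diagram. It also assumes n_actions is one action per child plus exit. The helper `tree_sizes` is hypothetical.

```python
def tree_sizes(depth: int, branching_factor: int = 2) -> tuple[int, int, int]:
    """Return (n_nodes, n_leaves, n_actions) for a perfect tree."""
    # Level k of a perfect tree holds branching_factor ** k nodes.
    n_nodes = sum(branching_factor ** k for k in range(depth + 1))
    n_leaves = branching_factor ** depth
    # One action per child, plus the exit action.
    n_actions = branching_factor + 1
    return n_nodes, n_leaves, n_actions


print(tree_sizes(4))  # (31, 16, 3)
print(tree_sizes(3))  # (15, 8, 3)
```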
reward(final_states)

Computes the reward for a batch of final states.

Parameters:

final_states – The final states.

Returns:

The reward of the final states.
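This page does not document the exact signature of reward_fn. The sketch below assumes it receives a batch of terminating node indices and returns one strictly positive reward per state, which is what GFlowNet training objectives require. The factory `make_reward_fn` and its "peak at one leaf" shape are hypothetical and chosen only for illustration.

```python
import math


def make_reward_fn(depth: int, target_leaf: int, temperature: float = 1.0):
    """Hypothetical reward: 1.0 at target_leaf, decaying with index distance."""
    first_leaf = 2 ** depth - 1
    last_leaf = 2 ** (depth + 1) - 2
    if not first_leaf <= target_leaf <= last_leaf:
        raise ValueError("target_leaf must be a terminating state")

    def reward_fn(final_states: list[int]) -> list[float]:
        rewards = []
        for s in final_states:
            if not first_leaf <= s <= last_leaf:
                raise ValueError(f"{s} is not a terminating state")
            # exp(...) is always > 0, as GFlowNet rewards must be.
            rewards.append(math.exp(-abs(s - target_leaf) / temperature))
        return rewards

    return reward_fn


reward_fn = make_reward_fn(depth=3, target_leaf=10)
print(reward_fn([10, 7, 14]))
```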

reward_fn
s0
sf
step(states, actions)

Performs a step in the environment.

Parameters:
  • states (gfn.env.DiscreteStates) – The current states.

  • actions (gfn.env.Actions) – The actions to take.

Returns:

The next states.

Return type:

gfn.env.DiscreteStates
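A forward step sends an internal node to one of its children and sends a terminating node to the sink on exit. The batched sketch below assumes heap-style numbering. It also assumes exit is only legal at leaves, which is the reading implied by the one-to-one correspondence between trajectories and terminating states. The sink is represented by a hypothetical sentinel of -1, which is not necessarily how the environment encodes sf.

```python
LEFT, RIGHT, EXIT = 0, 1, 2
SINK = -1  # hypothetical sentinel standing in for the sink state sf


def step(states: list[int], actions: list[int], depth: int) -> list[int]:
    """Apply one forward action to each state in a batch."""
    if len(states) != len(actions):
        raise ValueError("states and actions must have the same length")
    first_leaf = 2 ** depth - 1
    next_states = []
    for s, a in zip(states, actions):
        if s < 0:
            raise ValueError("cannot step from the sink")
        is_leaf = s >= first_leaf
        if a == EXIT:
            if not is_leaf:
                raise ValueError("exit is only valid at a terminating state")
            next_states.append(SINK)
        elif a in (LEFT, RIGHT):
            if is_leaf:
                raise ValueError("terminating states have no children")
            next_states.append(2 * s + 1 + a)
        else:
            raise ValueError(f"unknown action {a}")
    return next_states


print(step([0, 2, 9], [LEFT, RIGHT, EXIT], depth=3))  # [1, 6, -1]
```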

property terminating_states: gfn.env.DiscreteStates

Returns the terminating states of the environment.

Return type:

gfn.env.DiscreteStates
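Each terminating state is reached by exactly one trajectory, so the action sequence can be recovered from the leaf index alone by walking back to the root. The sketch below illustrates that bijection under the heap-numbering assumption, with 0 = left, 1 = right, and 2 = exit. The function `trajectory_to` is hypothetical.

```python
def trajectory_to(leaf: int, depth: int) -> list[int]:
    """Recover the unique action sequence from the root to a terminating node."""
    first_leaf = 2 ** depth - 1
    if not first_leaf <= leaf <= 2 ** (depth + 1) - 2:
        raise ValueError(f"{leaf} is not a terminating state")
    actions = []
    node = leaf
    while node > 0:
        actions.append((node - 1) % 2)  # 0 if node is a left child, 1 if right
        node = (node - 1) // 2
    actions.reverse()
    return actions + [2]  # finish with the exit action


print(trajectory_to(10, depth=3))  # [0, 1, 1, 2]
```

Read this way, the first depth actions of a trajectory spell out leaf - (2^depth - 1) in binary. For example, leaf 10 gives 10 - 7 = 3 = 011.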