Skip to content

Commit

Permalink
Small Restructure (#93)
Browse files Browse the repository at this point in the history
Restructures imports/exports and reorganises some search code.

No functional change.

Bench: 1733801
  • Loading branch information
jw1912 authored Jan 19, 2025
1 parent 38d6792 commit 7c7996e
Show file tree
Hide file tree
Showing 16 changed files with 221 additions and 206 deletions.
6 changes: 2 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@ The required networks will be downloaded automatically (and validated).
## Development

Development of Monty is facilitated by [montytest](https://tests.montychess.org/tests).
If you want to contribute, it is recommended to look in:
- [src/mcts/helpers.rs](src/mcts/helpers.rs) - location of functions that
calculate many important search heuristics, e.g. CPUCT scaling
- [src/mcts.rs](src/mcts.rs) - the actual search logic
If you want to contribute, it is recommended to look in the [mcts](src/mcts.rs) module
and its submodules.

Functional patches are required to pass on montytest, with an STC followed by an LTC test.

Expand Down
14 changes: 8 additions & 6 deletions datagen/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ use rng::Rand;
use thread::DatagenThread;

use monty::{
read_into_struct_unchecked, uci, ChessState, MappedWeights, MctsParams, PolicyNetwork,
ValueNetwork,
chess::ChessState,
mcts::MctsParams,
networks::{self, PolicyNetwork, ValueNetwork},
read_into_struct_unchecked, uci, MappedWeights,
};

use std::{
Expand All @@ -25,11 +27,11 @@ fn main() {
let mut args = std::env::args();
args.next();

let policy_mapped: MappedWeights<monty::PolicyNetwork> =
unsafe { read_into_struct_unchecked(monty::PolicyFileDefaultName) };
let policy_mapped: MappedWeights<networks::PolicyNetwork> =
unsafe { read_into_struct_unchecked(networks::PolicyFileDefaultName) };

let value_mapped: MappedWeights<monty::ValueNetwork> =
unsafe { read_into_struct_unchecked(monty::ValueFileDefaultName) };
let value_mapped: MappedWeights<networks::ValueNetwork> =
unsafe { read_into_struct_unchecked(networks::ValueFileDefaultName) };

let policy = &policy_mapped.data;
let value = &value_mapped.data;
Expand Down
10 changes: 6 additions & 4 deletions datagen/src/thread.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use crate::{Destination, Rand};

use monty::{
ChessState, GameState, Limits, MctsParams, PolicyNetwork, Searcher, Tree, ValueNetwork,
chess::{ChessState, GameState},
mcts::{Limits, MctsParams, Searcher},
networks::{PolicyNetwork, ValueNetwork},
tree::Tree,
};
use montyformat::{MontyFormat, MontyValueFormat, SearchData};

Expand Down Expand Up @@ -128,9 +131,8 @@ impl<'a> DatagenThread<'a> {
}

let abort = AtomicBool::new(false);
tree.try_use_subtree(&position, &None);
let searcher =
Searcher::new(position.clone(), &tree, &self.params, policy, value, &abort);
tree.set_root_position(&position);
let searcher = Searcher::new(&tree, &self.params, policy, value, &abort);

let (bm, score) = searcher.search(1, limits, false, &mut 0);

Expand Down
5 changes: 4 additions & 1 deletion src/bin/quantise-policy.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use std::io::Write;

use monty::{read_into_struct_unchecked, MappedWeights, PolicyNetwork, UnquantisedPolicyNetwork};
use monty::{
networks::{PolicyNetwork, UnquantisedPolicyNetwork},
read_into_struct_unchecked, MappedWeights,
};

fn main() {
let unquantised: MappedWeights<UnquantisedPolicyNetwork> =
Expand Down
4 changes: 2 additions & 2 deletions src/chess.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ mod frc;
mod moves;

use crate::{
networks::{Accumulator, POLICY_L1},
MctsParams, PolicyNetwork, ValueNetwork,
mcts::MctsParams,
networks::{Accumulator, PolicyNetwork, ValueNetwork, POLICY_L1},
};

pub use self::{attacks::Attacks, board::Board, frc::Castling, moves::Move};
Expand Down
2 changes: 1 addition & 1 deletion src/chess/board.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{pop_lsb, GameState};
use crate::{chess::GameState, pop_lsb};

use super::{
attacks::Attacks,
Expand Down
15 changes: 4 additions & 11 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,10 @@
pub(crate) mod chess;
mod mcts;
mod networks;
mod tree;
pub mod chess;
pub mod mcts;
pub mod networks;
pub mod tree;
pub mod uci;

pub use chess::{Board, Castling, ChessState, GameState, Move};
pub use mcts::{Limits, MctsParams, Searcher};
use memmap2::Mmap;
pub use networks::{
PolicyFileDefaultName, PolicyNetwork, UnquantisedPolicyNetwork, ValueFileDefaultName,
ValueNetwork,
};
pub use tree::Tree;

pub struct MappedWeights<'a, T> {
pub mmap: Mmap, // The memory-mapped file
Expand Down
20 changes: 14 additions & 6 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@ fn main() {
#[cfg(feature = "embed")]
mod net {
use memmap2::Mmap;
use monty::{uci, ChessState, MctsParams, PolicyNetwork, ValueNetwork};
use monty::{
chess::ChessState,
mcts::MctsParams,
networks::{PolicyNetwork, ValueNetwork},
uci,
};
use once_cell::sync::Lazy;
use sha2::{Digest, Sha256};
use std::fs::{self, File};
Expand Down Expand Up @@ -236,17 +241,20 @@ mod net {

#[cfg(not(feature = "embed"))]
mod nonet {
use monty::{read_into_struct_unchecked, uci, ChessState, MappedWeights, MctsParams};
use monty::{
chess::ChessState, mcts::MctsParams, networks, read_into_struct_unchecked, uci,
MappedWeights,
};

pub fn run() {
let mut args = std::env::args();
let arg1 = args.nth(1);

let policy_mapped: MappedWeights<monty::PolicyNetwork> =
unsafe { read_into_struct_unchecked(monty::PolicyFileDefaultName) };
let policy_mapped: MappedWeights<networks::PolicyNetwork> =
unsafe { read_into_struct_unchecked(networks::PolicyFileDefaultName) };

let value_mapped: MappedWeights<monty::ValueNetwork> =
unsafe { read_into_struct_unchecked(monty::ValueFileDefaultName) };
let value_mapped: MappedWeights<networks::ValueNetwork> =
unsafe { read_into_struct_unchecked(networks::ValueFileDefaultName) };

let policy = policy_mapped.data;
let value = value_mapped.data;
Expand Down
154 changes: 18 additions & 136 deletions src/mcts.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
mod helpers;
mod iteration;
mod params;

pub use helpers::SearchHelpers;
pub use params::MctsParams;

use crate::{
chess::Move,
tree::{Node, NodePtr, Tree},
ChessState, GameState, PolicyNetwork, ValueNetwork,
chess::{GameState, Move},
networks::{PolicyNetwork, ValueNetwork},
tree::{NodePtr, Tree},
};

use std::{
Expand All @@ -34,7 +35,6 @@ pub struct SearchStats {
}

pub struct Searcher<'a> {
root_position: ChessState,
tree: &'a Tree,
params: &'a MctsParams,
policy: &'a PolicyNetwork,
Expand All @@ -44,15 +44,13 @@ pub struct Searcher<'a> {

impl<'a> Searcher<'a> {
pub fn new(
root_position: ChessState,
tree: &'a Tree,
params: &'a MctsParams,
policy: &'a PolicyNetwork,
value: &'a ValueNetwork,
abort: &'a AtomicBool,
) -> Self {
Self {
root_position,
tree,
params,
policy,
Expand Down Expand Up @@ -105,11 +103,10 @@ impl<'a> Searcher<'a> {
F: FnMut() -> bool,
{
loop {
let mut pos = self.root_position.clone();
let mut pos = self.tree.root_position().clone();
let mut this_depth = 0;

if self
.perform_one_iteration(&mut pos, self.tree.root_node(), &mut this_depth)
if iteration::perform_one(self, &mut pos, self.tree.root_node(), &mut this_depth)
.is_none()
{
return false;
Expand Down Expand Up @@ -250,6 +247,7 @@ impl<'a> Searcher<'a> {
#[cfg(not(feature = "uci-minimal"))]
let mut timer_last_output = Instant::now();

let pos = self.tree.root_position();
let node = self.tree.root_node();

// the root node is added to an empty tree, **and not counted** towards the
Expand All @@ -260,16 +258,15 @@ impl<'a> Searcher<'a> {
assert_eq!(node, ptr);

self.tree[ptr].clear();
self.tree
.expand_node(ptr, &self.root_position, self.params, self.policy, 1);
self.tree.expand_node(ptr, pos, self.params, self.policy, 1);

let root_eval = self.root_position.get_value_wdl(self.value, self.params);
let root_eval = pos.get_value_wdl(self.value, self.params);
self.tree[ptr].update(1.0 - root_eval);
}
// relabel preexisting root policies with root PST value
else if self.tree[node].has_children() {
self.tree
.relabel_policy(node, &self.root_position, self.params, self.policy, 1);
.relabel_policy(node, pos, self.params, self.policy, 1);

let first_child_ptr = { *self.tree[node].actions() };

Expand All @@ -280,10 +277,10 @@ impl<'a> Searcher<'a> {
continue;
}

let mut position = self.root_position.clone();
position.make_move(self.tree[ptr].parent_move());
let mut child = pos.clone();
child.make_move(self.tree[ptr].parent_move());
self.tree
.relabel_policy(ptr, &position, self.params, self.policy, 2);
.relabel_policy(ptr, &child, self.params, self.policy, 2);
}
}

Expand Down Expand Up @@ -336,124 +333,6 @@ impl<'a> Searcher<'a> {
(mov, q)
}

fn perform_one_iteration(
&self,
pos: &mut ChessState,
ptr: NodePtr,
depth: &mut usize,
) -> Option<f32> {
*depth += 1;

let hash = pos.hash();
let node = &self.tree[ptr];

let mut u = if node.is_terminal() || node.visits() == 0 {
if node.visits() == 0 {
node.set_state(pos.game_state());
}

// probe hash table to use in place of network
if node.state() == GameState::Ongoing {
if let Some(entry) = self.tree.probe_hash(hash) {
entry.q()
} else {
self.get_utility(ptr, pos)
}
} else {
self.get_utility(ptr, pos)
}
} else {
// expand node on the second visit
if node.is_not_expanded() {
self.tree
.expand_node(ptr, pos, self.params, self.policy, *depth)?;
}

// this node has now been accessed so we need to move its
// children across if they are in the other tree half
self.tree.fetch_children(ptr)?;

// select action to take via PUCT
let action = self.pick_action(ptr, node);

let first_child_ptr = { *node.actions() };
let child_ptr = first_child_ptr + action;

let mov = self.tree[child_ptr].parent_move();

pos.make_move(mov);

self.tree[child_ptr].inc_threads();

// acquire lock to avoid issues with desynced setting of
// game state between threads when threads > 1
let lock = if self.tree[child_ptr].visits() == 0 {
Some(node.actions_mut())
} else {
None
};

// descend further
let maybe_u = self.perform_one_iteration(pos, child_ptr, depth);

drop(lock);

self.tree[child_ptr].dec_threads();

let u = maybe_u?;

self.tree
.propogate_proven_mates(ptr, self.tree[child_ptr].state());

u
};

// node scores are stored from the perspective
// **of the parent**, as they are usually only
// accessed from the parent's POV
u = 1.0 - u;

let new_q = node.update(u);
self.tree.push_hash(hash, 1.0 - new_q);

Some(u)
}

fn get_utility(&self, ptr: NodePtr, pos: &ChessState) -> f32 {
match self.tree[ptr].state() {
GameState::Ongoing => pos.get_value_wdl(self.value, self.params),
GameState::Draw => 0.5,
GameState::Lost(_) => 0.0,
GameState::Won(_) => 1.0,
}
}

fn pick_action(&self, ptr: NodePtr, node: &Node) -> usize {
let is_root = ptr == self.tree.root_node();

let cpuct = SearchHelpers::get_cpuct(self.params, node, is_root);
let fpu = SearchHelpers::get_fpu(node);
let expl_scale = SearchHelpers::get_explore_scaling(self.params, node);

let expl = cpuct * expl_scale;

self.tree.get_best_child_by_key(ptr, |child| {
let mut q = SearchHelpers::get_action_value(child, fpu);

// virtual loss
let threads = f64::from(child.threads());
if threads > 0.0 {
let visits = f64::from(child.visits());
let q2 = f64::from(q) * visits / (visits + threads);
q = q2 as f32;
}

let u = expl * child.policy() / (1 + child.visits()) as f32;

q + u
})
}

fn search_report(&self, depth: usize, seldepth: usize, timer: &Instant, nodes: usize) {
print!("info depth {depth} seldepth {seldepth} ");
let (pv_line, score) = self.get_pv(depth);
Expand All @@ -474,7 +353,7 @@ impl<'a> Searcher<'a> {
print!("time {ms} nodes {nodes} nps {nps:.0} pv");

for mov in pv_line {
print!(" {}", self.root_position.conv_mov_to_str(mov));
print!(" {}", self.tree.root_position().conv_mov_to_str(mov));
}

println!();
Expand Down Expand Up @@ -537,7 +416,10 @@ impl<'a> Searcher<'a> {
let first_child_ptr = { *self.tree[self.tree.root_node()].actions() };
for action in 0..self.tree[self.tree.root_node()].num_actions() {
let child = &self.tree[first_child_ptr + action];
let mov = self.root_position.conv_mov_to_str(child.parent_move());
let mov = self
.tree
.root_position()
.conv_mov_to_str(child.parent_move());
let q = child.q() * 100.0;
println!(
"{mov} -> {q:.2}% V({}) S({})",
Expand Down
Loading

0 comments on commit 7c7996e

Please sign in to comment.