// This file is part of hnefatafl-copenhagen.
//
// hnefatafl-copenhagen is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// hnefatafl-copenhagen is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <https://www.gnu.org/licenses/>.
//
// SPDX-License-Identifier: AGPL-3.0-or-later
// SPDX-FileCopyrightText: 2025 David Campbell <david@hnefatafl.org>

            
19
use std::{
20
    fmt,
21
    time::{Duration, Instant},
22
};
23

            
24
use rand::prelude::*;
25
use rustc_hash::FxHashMap;
26
use serde::{Deserialize, Serialize};
27

            
28
use crate::{
29
    board::BoardSize,
30
    game::Game,
31
    play::{Plae, Plays},
32
    status::Status,
33
};
34

            
35
#[derive(Clone, Debug)]
36
pub struct Tree {
37
    here: u64,
38
    pub game: Game,
39
    arena: FxHashMap<u64, Node>,
40
}
41

            
42
impl Tree {
43
15192
    fn insert_child(&mut self, child_index: u64, parent_index: u64, play: Plae) {
44
15192
        let node = self.arena.get_mut(&parent_index).unwrap_or_else(|| {
45
            println!("The hashmap should have the node {parent_index}.");
46
            unreachable!();
47
        });
48

            
49
15192
        node.children.push(child_index);
50
15192
        let board_size = node.board_size;
51

            
52
15192
        self.arena.insert(
53
15192
            child_index,
54
15192
            Node {
55
15192
                board_size,
56
15192
                play: Some(play),
57
15192
                score: 0.0,
58
15192
                count: 1.0,
59
15192
                parent: Some(parent_index),
60
15192
                children: Vec::new(),
61
15192
            },
62
        );
63
15192
    }
64

            
65
    #[allow(clippy::expect_used)]
66
    #[allow(clippy::missing_panics_doc)]
67
    #[must_use]
68
3
    pub fn monte_carlo_tree_search(&mut self, duration: Duration, depth: u8) -> (u64, Vec<Node>) {
69
        // Doesn't seem to do much... and makes the overall search slower, about 5%.
70
        /*
71
        if self.game.previous_boards.0.len() == 1 {
72

            
73
        }
74
        */
75

            
76
3
        let t0 = Instant::now();
77
3
        let mut rng = rand::rng();
78
3
        let mut loops = 0;
79

            
80
        loop {
81
197
            let t1 = Instant::now();
82
197
            let elapsed_time = t1 - t0;
83

            
84
197
            if duration < elapsed_time {
85
3
                break;
86
194
            }
87
194
            loops += 1;
88

            
89
194
            let mut game = self.game.clone();
90
194
            let mut here = self.here;
91

            
92
194
            for _ in 0..depth {
93
15232
                let play = if let Some(play) = game.obvious_play() {
94
19
                    game.play(&play).expect("The play should be legal!");
95
19
                    play
96
                } else {
97
15213
                    let plays = game.all_legal_plays();
98
15213
                    let index = rng.random_range(0..plays.len());
99
15213
                    let play = plays[index].clone();
100
15213
                    game.play(&play).expect("The play should be legal!");
101
15213
                    play
102
                };
103

            
104
15232
                let child_index = game.calculate_hash();
105
15232
                if let Some(node) = self.arena.get_mut(&child_index) {
106
40
                    node.count += 1.0;
107
15192
                } else {
108
15192
                    self.insert_child(child_index, here, play);
109
15192
                }
110
15232
                here = child_index;
111

            
112
15232
                let gamma = 0.95;
113

            
114
15232
                match game.status {
115
                    Status::AttackerWins => {
116
4
                        let node = self
117
4
                            .arena
118
4
                            .get_mut(&here)
119
4
                            .expect("The hashmap should have the node.");
120

            
121
4
                        node.score += 1.0;
122
4
                        let mut g = 1.0;
123

            
124
284
                        while let Some(node) = self.arena[&here].parent {
125
280
                            let real_node =
126
280
                                self.arena.get_mut(&node).expect("The node should exist!");
127
280

            
128
280
                            g *= gamma;
129
280
                            real_node.score += g;
130
280
                            here = node;
131
280
                        }
132

            
133
4
                        break;
134
                    }
135
                    Status::DefenderWins => {
136
15
                        let node = self
137
15
                            .arena
138
15
                            .get_mut(&here)
139
15
                            .expect("The hashmap should have the node.");
140

            
141
15
                        node.score -= 1.0;
142
15
                        let mut g = -1.0;
143

            
144
967
                        while let Some(node) = self.arena[&here].parent {
145
952
                            let real_node =
146
952
                                self.arena.get_mut(&node).expect("The node should exist!");
147
952

            
148
952
                            g *= gamma;
149
952
                            real_node.score += g;
150
952
                            here = node;
151
952
                        }
152

            
153
15
                        break;
154
                    }
155
                    Status::Draw => unreachable!(),
156
15213
                    Status::Ongoing => {
157
15213
                        // Keep going.
158
15213
                    }
159
                }
160
            }
161
        }
162

            
163
15195
        for node in self.arena.values_mut() {
164
15195
            node.score /= node.count;
165
15195
            node.count = 1.0;
166
15195
        }
167

            
168
3
        let children = &self.arena[&self.here].children;
169
        (
170
3
            loops,
171
3
            children
172
3
                .iter()
173
154
                .map(|child| self.arena[child].clone())
174
3
                .collect::<Vec<_>>(),
175
        )
176
3
    }
177

            
178
    #[must_use]
179
3
    pub fn new(game: Game) -> Self {
180
3
        let hash = game.calculate_hash();
181
3
        let mut arena = FxHashMap::default();
182
3
        arena.insert(
183
3
            hash,
184
3
            Node {
185
3
                board_size: game.board.size(),
186
3
                play: None,
187
3
                score: 0.0,
188
3
                count: 0.0,
189
3
                parent: None,
190
3
                children: Vec::new(),
191
3
            },
192
        );
193

            
194
3
        Self {
195
3
            here: hash,
196
3
            game,
197
3
            arena,
198
3
        }
199
3
    }
200
}
201

            
202
impl From<Game> for Tree {
203
    fn from(game: Game) -> Self {
204
        let mut arena = FxHashMap::default();
205

            
206
        let play = match &game.plays {
207
            Plays::PlayRecords(plays) => {
208
                if let Some(play) = plays.last() {
209
                    play.clone()
210
                } else {
211
                    None
212
                }
213
            }
214
            Plays::PlayRecordsTimed(plays) => {
215
                if let Some(timing) = plays.last() {
216
                    timing.play.clone()
217
                } else {
218
                    None
219
                }
220
            }
221
        };
222

            
223
        let hash = game.calculate_hash();
224
        arena.insert(
225
            hash,
226
            Node {
227
                board_size: game.board.size(),
228
                play: play.clone(),
229
                score: 0.0,
230
                count: 0.0,
231
                parent: None,
232
                children: Vec::new(),
233
            },
234
        );
235

            
236
        Self {
237
            here: hash,
238
            game,
239
            arena,
240
        }
241
    }
242
}
243

            
244
#[derive(Clone, Debug, Deserialize, Serialize)]
245
pub struct Node {
246
    pub board_size: BoardSize,
247
    pub play: Option<Plae>,
248
    pub score: f64,
249
    pub count: f64,
250
    parent: Option<u64>,
251
    children: Vec<u64>,
252
}
253

            
254
impl fmt::Display for Node {
255
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
256
        if let Some(play) = &self.play {
257
            write!(
258
                f,
259
                "play: {play}, score: {}, count: {}",
260
                self.score, self.count
261
            )
262
        } else {
263
            write!(f, "play: None")
264
        }
265
    }
266
}