// This file is part of hnefatafl-copenhagen.
//
// hnefatafl-copenhagen is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// hnefatafl-copenhagen is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <https://www.gnu.org/licenses/>.

            
16
use std::{
17
    fmt,
18
    time::{Duration, Instant},
19
};
20

            
21
use rand::prelude::*;
22
use rustc_hash::FxHashMap;
23
use serde::{Deserialize, Serialize};
24

            
25
use crate::{
26
    board::BoardSize,
27
    game::Game,
28
    play::{Plae, Plays},
29
    status::Status,
30
};
31

            
32
#[derive(Clone, Debug)]
33
pub struct Tree {
34
    here: u64,
35
    pub game: Game,
36
    arena: FxHashMap<u64, Node>,
37
}
38

            
39
impl Tree {
40
29650
    fn insert_child(&mut self, child_index: u64, parent_index: u64, play: Plae) {
41
29650
        let node = self.arena.get_mut(&parent_index).unwrap_or_else(|| {
42
            println!("The hashmap should have the node {parent_index}.");
43
            unreachable!();
44
        });
45

            
46
29650
        node.children.push(child_index);
47
29650
        let board_size = node.board_size;
48

            
49
29650
        self.arena.insert(
50
29650
            child_index,
51
29650
            Node {
52
29650
                board_size,
53
29650
                play: Some(play),
54
29650
                score: 0.0,
55
29650
                count: 1.0,
56
29650
                parent: Some(parent_index),
57
29650
                children: Vec::new(),
58
29650
            },
59
        );
60
29650
    }
61

            
62
    #[allow(clippy::expect_used)]
63
    #[allow(clippy::missing_panics_doc)]
64
    #[must_use]
65
6
    pub fn monte_carlo_tree_search(&mut self, duration: Duration, depth: u8) -> (u64, Vec<Node>) {
66
        // Doesn't seem to do much... and makes the overall search slower, about 5%.
67
        /*
68
        if self.game.previous_boards.0.len() == 1 {
69

            
70
        }
71
        */
72

            
73
6
        let t0 = Instant::now();
74
6
        let mut rng = rand::rng();
75
6
        let mut loops = 0;
76

            
77
        loop {
78
383
            let t1 = Instant::now();
79
383
            let elapsed_time = t1 - t0;
80

            
81
383
            if duration < elapsed_time {
82
6
                break;
83
377
            }
84
377
            loops += 1;
85

            
86
377
            let mut game = self.game.clone();
87
377
            let mut here = self.here;
88

            
89
377
            for _ in 0..depth {
90
29741
                let play = if let Some(play) = game.obvious_play() {
91
25
                    game.play(&play).expect("The play should be legal!");
92
25
                    play
93
                } else {
94
29716
                    let plays = game.all_legal_plays();
95
29716
                    let index = rng.random_range(0..plays.len());
96
29716
                    let play = plays[index].clone();
97
29716
                    game.play(&play).expect("The play should be legal!");
98
29716
                    play
99
                };
100

            
101
29741
                let child_index = game.calculate_hash();
102
29741
                if let Some(node) = self.arena.get_mut(&child_index) {
103
91
                    node.count += 1.0;
104
29650
                } else {
105
29650
                    self.insert_child(child_index, here, play);
106
29650
                }
107
29741
                here = child_index;
108

            
109
29741
                let gamma = 0.95;
110

            
111
29741
                match game.status {
112
                    Status::AttackerWins => {
113
7
                        let node = self
114
7
                            .arena
115
7
                            .get_mut(&here)
116
7
                            .expect("The hashmap should have the node.");
117

            
118
7
                        node.score += 1.0;
119
7
                        let mut g = 1.0;
120

            
121
440
                        while let Some(node) = self.arena[&here].parent {
122
433
                            let real_node =
123
433
                                self.arena.get_mut(&node).expect("The node should exist!");
124
433

            
125
433
                            g *= gamma;
126
433
                            real_node.score += g;
127
433
                            here = node;
128
433
                        }
129

            
130
7
                        break;
131
                    }
132
                    Status::DefenderWins => {
133
18
                        let node = self
134
18
                            .arena
135
18
                            .get_mut(&here)
136
18
                            .expect("The hashmap should have the node.");
137

            
138
18
                        node.score -= 1.0;
139
18
                        let mut g = -1.0;
140

            
141
1166
                        while let Some(node) = self.arena[&here].parent {
142
1148
                            let real_node =
143
1148
                                self.arena.get_mut(&node).expect("The node should exist!");
144
1148

            
145
1148
                            g *= gamma;
146
1148
                            real_node.score += g;
147
1148
                            here = node;
148
1148
                        }
149

            
150
18
                        break;
151
                    }
152
                    Status::Draw => unreachable!(),
153
29716
                    Status::Ongoing => {
154
29716
                        // Keep going.
155
29716
                    }
156
                }
157
            }
158
        }
159

            
160
29656
        for node in self.arena.values_mut() {
161
29656
            node.score /= node.count;
162
29656
            node.count = 1.0;
163
29656
        }
164

            
165
6
        let children = &self.arena[&self.here].children;
166
        (
167
6
            loops,
168
6
            children
169
6
                .iter()
170
288
                .map(|child| self.arena[child].clone())
171
6
                .collect::<Vec<_>>(),
172
        )
173
6
    }
174

            
175
    #[must_use]
176
6
    pub fn new(game: Game) -> Self {
177
6
        let hash = game.calculate_hash();
178
6
        let mut arena = FxHashMap::default();
179
6
        arena.insert(
180
6
            hash,
181
6
            Node {
182
6
                board_size: game.board.size(),
183
6
                play: None,
184
6
                score: 0.0,
185
6
                count: 0.0,
186
6
                parent: None,
187
6
                children: Vec::new(),
188
6
            },
189
        );
190

            
191
6
        Self {
192
6
            here: hash,
193
6
            game,
194
6
            arena,
195
6
        }
196
6
    }
197
}
198

            
199
impl From<Game> for Tree {
200
    fn from(game: Game) -> Self {
201
        let mut arena = FxHashMap::default();
202

            
203
        let play = match &game.plays {
204
            Plays::PlayRecords(plays) => {
205
                if let Some(play) = plays.last() {
206
                    play.clone()
207
                } else {
208
                    None
209
                }
210
            }
211
            Plays::PlayRecordsTimed(plays) => {
212
                if let Some(timing) = plays.last() {
213
                    timing.play.clone()
214
                } else {
215
                    None
216
                }
217
            }
218
        };
219

            
220
        let hash = game.calculate_hash();
221
        arena.insert(
222
            hash,
223
            Node {
224
                board_size: game.board.size(),
225
                play: play.clone(),
226
                score: 0.0,
227
                count: 0.0,
228
                parent: None,
229
                children: Vec::new(),
230
            },
231
        );
232

            
233
        Self {
234
            here: hash,
235
            game,
236
            arena,
237
        }
238
    }
239
}
240

            
241
#[derive(Clone, Debug, Deserialize, Serialize)]
242
pub struct Node {
243
    pub board_size: BoardSize,
244
    pub play: Option<Plae>,
245
    pub score: f64,
246
    pub count: f64,
247
    parent: Option<u64>,
248
    children: Vec<u64>,
249
}
250

            
251
impl fmt::Display for Node {
252
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
253
        if let Some(play) = &self.play {
254
            write!(
255
                f,
256
                "play: {play}, score: {}, count: {}",
257
                self.score, self.count
258
            )
259
        } else {
260
            write!(f, "play: None")
261
        }
262
    }
263
}