1
// This file is part of hnefatafl-copenhagen.
2
//
3
// hnefatafl-copenhagen is free software: you can redistribute it and/or modify
4
// it under the terms of the GNU Affero General Public License as published by
5
// the Free Software Foundation, either version 3 of the License, or
6
// (at your option) any later version.
7
//
8
// hnefatafl-copenhagen is distributed in the hope that it will be useful,
9
// but WITHOUT ANY WARRANTY; without even the implied warranty of
10
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11
// GNU Affero General Public License for more details.
12
//
13
// You should have received a copy of the GNU Affero General Public License
14
// along with this program.  If not, see <https://www.gnu.org/licenses/>.
15
//
16
// SPDX-License-Identifier: AGPL-3.0-or-later
17
// SPDX-FileCopyrightText: 2025 David Campbell <david@hnefatafl.org>
18

            
19
use std::{fmt, sync::mpsc::channel, time::Duration};
20

            
21
use jiff::Timestamp;
22
use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
23
use rustc_hash::FxHashMap;
24

            
25
use crate::{
26
    board::InvalidMove,
27
    game::{EscapeVec, Game},
28
    game_tree::{Node, Tree},
29
    heat_map::HeatMap,
30
    play::Plae,
31
    role::Role,
32
    status::Status,
33
};
34

            
35
pub trait AI: Send {
36
    /// # Errors
37
    ///
38
    /// When the game is already over.
39
    fn generate_move(&mut self, game: &mut Game) -> anyhow::Result<GenerateMove>;
40
    #[allow(clippy::missing_errors_doc)]
41
    fn play(&mut self, game: &mut Game, play: &Plae) -> anyhow::Result<()> {
42
        game.play(play)?;
43
        Ok(())
44
    }
45
}
46

            
47
#[derive(Clone, Debug)]
48
pub struct GenerateMove {
49
    pub play: Plae,
50
    pub score: f64,
51
    pub delay_milliseconds: i64,
52
    pub loops: u64,
53
    pub heat_map: HeatMap,
54
    pub escape_vec: Option<EscapeVec>,
55
}
56

            
57
impl fmt::Display for GenerateMove {
58
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
59
        writeln!(
60
            f,
61
            "{}, score: {}, delay milliseconds: {}, loops: {}",
62
            self.play, self.score, self.delay_milliseconds, self.loops
63
        )?;
64

            
65
        if let Some(escape_vec) = &self.escape_vec {
66
            write!(f, "escape_vec:\n\n{escape_vec}")?;
67
        }
68

            
69
        Ok(())
70
    }
71
}
72

            
73
/// The simplest possible AI: it always takes the first legal play.
#[derive(Debug, Default, Clone)]
pub struct AiBanal;
75

            
76
impl AI for AiBanal {
77
366
    fn generate_move(&mut self, game: &mut Game) -> anyhow::Result<GenerateMove> {
78
366
        if game.status != Status::Ongoing {
79
3
            return Err(InvalidMove::GameOver.into());
80
363
        }
81

            
82
363
        let play = game.all_legal_plays()[0].clone();
83
363
        game.play(&play)?;
84

            
85
363
        Ok(GenerateMove {
86
363
            play,
87
363
            score: 0.0,
88
363
            delay_milliseconds: 0,
89
363
            loops: 0,
90
363
            heat_map: HeatMap::new(game.board.size()),
91
363
            escape_vec: None,
92
363
        })
93
366
    }
94
}
95

            
96
/// An AI that picks moves with a fixed-depth alpha-beta search.
///
/// Derives `Clone` and `Debug` like the other AIs in this module
/// (`AiBanal`, `AiMonteCarlo`).
#[derive(Clone, Debug)]
pub struct AiBasic {
    /// Search depth handed to the alpha-beta search.
    depth: u8,
    /// `true` runs the sequential alpha-beta search; `false` runs the parallel one.
    sequential: bool,
}

impl AiBasic {
    /// Creates an alpha-beta AI searching to `depth`, either sequentially or in parallel.
    #[must_use]
    pub fn new(depth: u8, sequential: bool) -> Self {
        Self { depth, sequential }
    }
}
107

            
108
impl AI for AiBasic {
109
    fn generate_move(&mut self, game: &mut Game) -> anyhow::Result<GenerateMove> {
110
        let t0 = Timestamp::now().as_millisecond();
111

            
112
        if game.status != Status::Ongoing {
113
            return Err(InvalidMove::GameOver.into());
114
        }
115

            
116
        if let Some(play) = game.obvious_play() {
117
            println!("1 turn: {} play: {play}", game.turn);
118

            
119
            game.play(&play)?;
120
            let score = match game.turn {
121
                Role::Attacker => f64::INFINITY,
122
                Role::Defender => -f64::INFINITY,
123
                Role::Roleless => unreachable!(),
124
            };
125

            
126
            let heat_map = HeatMap::from((&*game, &play));
127
            let t1 = Timestamp::now().as_millisecond();
128
            let delay_milliseconds = t1 - t0;
129

            
130
            return Ok(GenerateMove {
131
                play,
132
                score,
133
                delay_milliseconds,
134
                loops: 0,
135
                heat_map,
136
                escape_vec: None,
137
            });
138
        }
139

            
140
        let (play, score, escape_vec) = if self.sequential {
141
            game.alpha_beta(
142
                self.depth as usize,
143
                self.depth,
144
                None,
145
                -f64::INFINITY,
146
                f64::INFINITY,
147
            )
148
        } else {
149
            game.alpha_beta_parallel(
150
                self.depth as usize,
151
                self.depth,
152
                None,
153
                -f64::INFINITY,
154
                f64::INFINITY,
155
            )
156
        };
157

            
158
        let play = match play {
159
            Some(play) => play,
160
            None => match &game.turn {
161
                Role::Attacker => Plae::AttackerResigns,
162
                Role::Defender => Plae::DefenderResigns,
163
                Role::Roleless => unreachable!(),
164
            },
165
        };
166

            
167
        println!("2 turn: {} play: {play}", game.turn);
168
        game.play(&play)?;
169

            
170
        let heat_map = HeatMap::from((&*game, &play));
171

            
172
        let t1 = Timestamp::now().as_millisecond();
173
        let delay_milliseconds = t1 - t0;
174

            
175
        Ok(GenerateMove {
176
            play,
177
            score,
178
            delay_milliseconds,
179
            loops: 0,
180
            heat_map,
181
            escape_vec,
182
        })
183
    }
184
}
185

            
186
/// An AI that runs Monte Carlo tree searches in parallel and merges their results.
#[derive(Debug, Clone)]
pub struct AiMonteCarlo {
    /// Time budget passed to each tree's Monte Carlo search.
    duration: Duration,
    /// Depth limit passed to each tree's Monte Carlo search.
    depth: u8,
}

impl Default for AiMonteCarlo {
    /// One second of search per move with a depth limit of 80.
    fn default() -> Self {
        let duration = Duration::from_millis(1_000);
        let depth = 80;
        Self { duration, depth }
    }
}
200

            
201
impl AI for AiMonteCarlo {
202
    fn generate_move(&mut self, game: &mut Game) -> anyhow::Result<GenerateMove> {
203
        if game.status != Status::Ongoing {
204
            return Err(InvalidMove::GameOver.into());
205
        }
206

            
207
        let t0 = Timestamp::now().as_millisecond();
208
        let mut trees = AiMonteCarlo::make_trees(game)?;
209
        let (tx, rx) = channel();
210

            
211
        trees.par_iter_mut().try_for_each_with(tx, |tx, tree| {
212
            let nodes = tree.monte_carlo_tree_search(self.duration, self.depth);
213
            tx.send(nodes)
214
        })?;
215

            
216
        let mut loops_total = 0;
217
        let mut nodes_master = FxHashMap::default();
218

            
219
        while let Ok((loops, nodes)) = rx.recv() {
220
            loops_total += loops;
221
            for mut node in nodes {
222
                if let Some(Plae::Play(play)) = node.clone().play {
223
                    nodes_master
224
                        .entry(play)
225
                        .and_modify(|node_master: &mut Node| {
226
                            if node_master.count == 0.0 {
227
                                node_master.count = 1.0;
228
                                node_master.score = node.score;
229
                            } else {
230
                                node_master.count += 1.0;
231
                                node_master.score += node.score;
232
                            }
233
                        })
234
                        .or_insert({
235
                            node.count = 1.0;
236
                            node
237
                        });
238
                }
239
            }
240
        }
241

            
242
        for node in nodes_master.values_mut() {
243
            node.score /= node.count;
244
            node.count = 1.0;
245
        }
246

            
247
        let mut nodes: Vec<_> = nodes_master.values().collect();
248
        nodes.sort_by(|a, b| a.score.total_cmp(&b.score));
249

            
250
        let turn = game.turn;
251
        let message = anyhow::Error::msg("The nodes are empty.");
252
        let node = match turn {
253
            Role::Attacker => nodes.last().ok_or(message)?,
254
            Role::Defender => nodes.first().ok_or(message)?,
255
            Role::Roleless => unreachable!(),
256
        };
257

            
258
        let play = node
259
            .play
260
            .as_ref()
261
            .ok_or(anyhow::Error::msg("A move has not been played yet."))?;
262

            
263
        game.play(play)?;
264

            
265
        let here_tree = Tree::from(game.clone());
266
        for tree in &mut trees {
267
            *tree = here_tree.clone();
268
        }
269

            
270
        let t1 = Timestamp::now().as_millisecond();
271
        let delay_milliseconds = t1 - t0;
272
        let heat_map = HeatMap::from(&nodes);
273

            
274
        Ok(GenerateMove {
275
            play: play.clone(),
276
            score: node.score,
277
            delay_milliseconds,
278
            loops: loops_total,
279
            heat_map,
280
            escape_vec: None,
281
        })
282
    }
283
}
284

            
285
impl AiMonteCarlo {
286
    fn make_trees(game: &Game) -> anyhow::Result<Vec<Tree>> {
287
        let count = std::thread::available_parallelism()?.get();
288
        let mut trees = Vec::with_capacity(count);
289

            
290
        for _ in 0..count {
291
            trees.push(Tree::new(game.clone()));
292
        }
293

            
294
        Ok(trees)
295
    }
296

            
297
    #[must_use]
298
    pub fn new(duration: Duration, depth: u8) -> Self {
299
        Self { duration, depth }
300
    }
301
}