1
// This file is part of hnefatafl-copenhagen.
2
//
3
// hnefatafl-copenhagen is free software: you can redistribute it and/or modify
4
// it under the terms of the GNU Affero General Public License as published by
5
// the Free Software Foundation, either version 3 of the License, or
6
// (at your option) any later version.
7
//
8
// hnefatafl-copenhagen is distributed in the hope that it will be useful,
9
// but WITHOUT ANY WARRANTY; without even the implied warranty of
10
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11
// GNU Affero General Public License for more details.
12
//
13
// You should have received a copy of the GNU Affero General Public License
14
// along with this program.  If not, see <https://www.gnu.org/licenses/>.
15
//
16
// SPDX-License-Identifier: AGPL-3.0-or-later
17
// SPDX-FileCopyrightText: 2025 David Campbell <david@hnefatafl.org>
18

            
19
#![deny(clippy::expect_used)]
20
#![deny(clippy::indexing_slicing)]
21
#![deny(clippy::panic)]
22
#![deny(clippy::unwrap_used)]
23

            
24
use std::{
25
    io::{BufRead, BufReader, Write},
26
    net::{TcpStream, ToSocketAddrs as _},
27
    process::Command,
28
    thread::sleep,
29
    time::Duration,
30
};
31

            
32
use anyhow::Error;
33
use clap::{CommandFactory, Parser};
34
use hnefatafl_copenhagen::{
35
    COPYRIGHT, LONG_VERSION, VERSION_ID,
36
    ai::AI,
37
    game::Game,
38
    play::Plae,
39
    role::Role,
40
    status::Status,
41
    tcp_keep_alive,
42
    utils::{self, choose_ai},
43
};
44
use log::{debug, error, info, trace};
45
use socket2::{Domain, SockAddr, Socket, Type};
46

            
47
// Move 26, defender wins, corner escape, time per move 15s 2025-03-06 (hnefatafl-equi).
48

            
49
/// Server port suffix (including the leading `:`); `main` appends it to
/// `--host` to build the `host:port` string that is resolved and connected to.
const PORT: &str = ":49152";
50

            
51
/// Copenhagen Hnefatafl AI
///
/// This is an AI client that connects to a server.
// NOTE: the `///` comments in this struct are rendered by clap as `--help`
// and man-page text, so editing them changes user-visible output.
// Several independent on/off CLI flags (`debug`, `sequential`, `systemd`,
// `man`) trip `clippy::struct_excessive_bools`; they are switches, not state.
#[allow(clippy::struct_excessive_bools)]
#[derive(Parser, Debug)]
#[command(long_version = LONG_VERSION, about = "Copenhagen Hnefatafl AI")]
struct Args {
    /// Enter your username
    // `main` logs in as "ai-" followed by this value.
    #[arg(long)]
    username: String,

    /// Enter your password
    // Sent verbatim on the login line; empty unless given.
    #[arg(default_value = "", long)]
    password: String,

    /// Set the role as attacker or defender
    // Used for `new_game` requests and, with `systemd`, to pick the unit name.
    #[arg(default_value_t = Role::Attacker, long)]
    role: Role,

    /// Connect to the HTP server at host
    // `PORT` is appended to this to form the address that is resolved.
    #[arg(default_value = "hnefatafl.org", long)]
    host: String,

    /// Choose an AI to play as
    // Passed to `choose_ai` together with `seconds`, `depth` and `sequential`.
    #[arg(default_value = "basic", long)]
    ai: String,

    /// Whether to log on the debug level
    #[arg(long)]
    debug: bool,

    /// How many seconds to run the monte-carlo AI
    #[arg(long)]
    seconds: Option<u64>,

    /// How deep in the game tree to go with Ai
    ///
    /// [default basic: 4]
    /// [default monte-carlo: 20]
    #[arg(long)]
    depth: Option<u8>,

    /// Join game with id
    // When set, `main` joins this one game instead of creating new games in a loop.
    #[arg(long)]
    join_game: Option<u128>,

    /// Run the basic AI sequentially
    #[arg(long)]
    sequential: bool,

    /// Whether the application is being run by systemd
    // Also passed to the logger and enables the restart back-off in
    // `systemd_delay_restart`.
    #[arg(long)]
    systemd: bool,

    /// Build the manpage
    // Writes `hnefatafl-ai.1` to the current directory and exits.
    #[arg(long)]
    man: bool,
}
109

            
110
fn main() -> anyhow::Result<()> {
111
    let args = Args::parse();
112
    utils::init_logger("hnefatafl_ai", args.debug, args.systemd);
113

            
114
    if args.man {
115
        let mut buffer: Vec<u8> = Vec::default();
116
        let cmd = Args::command().name("hnefatafl-ai").long_version(None);
117
        let man = clap_mangen::Man::new(cmd).date("2025-06-23");
118

            
119
        man.render(&mut buffer)?;
120
        write!(buffer, "{COPYRIGHT}")?;
121

            
122
        std::fs::write("hnefatafl-ai.1", buffer)?;
123
        return Ok(());
124
    }
125

            
126
    let mut username = "ai-".to_string();
127
    username.push_str(&args.username);
128

            
129
    let mut address_string = args.host.clone();
130
    address_string.push_str(PORT);
131

            
132
    let mut is_ipv6 = false;
133
    let mut socket_address = None;
134
    let socket_addresses = address_string.to_socket_addrs()?;
135

            
136
    for address in socket_addresses.clone() {
137
        if address.is_ipv6() {
138
            socket_address = Some(address);
139
            is_ipv6 = true;
140
            break;
141
        }
142
    }
143

            
144
    if !is_ipv6 {
145
        for address in socket_addresses {
146
            if address.is_ipv4() {
147
                socket_address = Some(address);
148
                break;
149
            }
150
        }
151
    }
152

            
153
    let socket_address = socket_address.ok_or_else(|| {
154
        anyhow::Error::msg(format!(
155
            "There is no IP address for the host: {address_string}"
156
        ))
157
    })?;
158

            
159
    let address: SockAddr = socket_address.into();
160
    let keep_alive = tcp_keep_alive();
161
    let domain_type = if is_ipv6 { Domain::IPV6 } else { Domain::IPV4 };
162
    let socket = Socket::new(domain_type, Type::STREAM, None)?;
163
    socket.set_tcp_keepalive(&keep_alive)?;
164

            
165
    systemd_delay_restart(&args)?;
166

            
167
    if let Err(error) = socket.connect(&address) {
168
        error!("socket.connect {address_string}: failed");
169
        return Err(error.into());
170
    }
171

            
172
    info!("connected to {socket_address}");
173

            
174
    let mut tcp: TcpStream = socket.into();
175
    let mut reader = BufReader::new(tcp.try_clone()?);
176

            
177
    tcp.write_all(format!("{VERSION_ID} login {username} {}\n", args.password).as_bytes())?;
178

            
179
    let mut buf = String::new();
180
    reader.read_line(&mut buf)?;
181
    assert_eq!(buf, "= login\n");
182
    buf.clear();
183

            
184
    if let Some(game_id) = args.join_game {
185
        tcp.write_all(format!("join_game_pending {game_id}\n").as_bytes())?;
186

            
187
        let ai = choose_ai(&args.ai, args.seconds, args.depth, args.sequential)?;
188
        handle_messages(ai, game_id, &mut reader, &mut tcp)?;
189
    } else {
190
        loop {
191
            new_game(&mut tcp, args.role, &mut reader, &mut buf)?;
192

            
193
            info!("{buf}");
194

            
195
            let message: Vec<_> = buf.split_ascii_whitespace().collect();
196
            let Some(message) = message.get(3) else {
197
                return Err(anyhow::Error::msg("Expecting message[3] to be a game_id"));
198
            };
199

            
200
            let game_id = message.parse()?;
201
            buf.clear();
202

            
203
            wait_for_challenger(&mut reader, &mut buf, &mut tcp, game_id)?;
204

            
205
            let ai = choose_ai(&args.ai, args.seconds, args.depth, args.sequential)?;
206
            handle_messages(ai, game_id, &mut reader, &mut tcp)?;
207
        }
208
    }
209

            
210
    Ok(())
211
}
212

            
213
fn new_game(
214
    tcp: &mut TcpStream,
215
    role: Role,
216
    reader: &mut BufReader<TcpStream>,
217
    buf: &mut String,
218
) -> anyhow::Result<()> {
219
    tcp.write_all(format!("new_game {role} rated fischer 900000 10 11\n").as_bytes())?;
220

            
221
    loop {
222
        reader.read_line(buf)?;
223

            
224
        if buf.trim().is_empty() {
225
            return Err(Error::msg("the TCP stream has closed"));
226
        }
227

            
228
        let message: Vec<_> = buf.split_ascii_whitespace().collect();
229
        if let Some(message) = message.get(1)
230
            && *message == "new_game"
231
        {
232
            return Ok(());
233
        }
234

            
235
        buf.clear();
236
    }
237
}
238

            
239
fn wait_for_challenger(
240
    reader: &mut BufReader<TcpStream>,
241
    buf: &mut String,
242
    tcp: &mut TcpStream,
243
    game_id: u128,
244
) -> anyhow::Result<()> {
245
    loop {
246
        reader.read_line(buf)?;
247

            
248
        if buf.trim().is_empty() {
249
            return Err(Error::msg("the TCP stream has closed"));
250
        }
251

            
252
        let message: Vec<_> = buf.split_ascii_whitespace().collect();
253
        if Some("challenge_requested") == message.get(1).copied() {
254
            info!("{message:?}");
255
            buf.clear();
256

            
257
            break;
258
        }
259

            
260
        buf.clear();
261
    }
262

            
263
    tcp.write_all(format!("join_game {game_id}\n").as_bytes())?;
264
    Ok(())
265
}
266

            
267
fn handle_messages(
268
    mut ai: Box<dyn AI>,
269
    game_id: u128,
270
    reader: &mut BufReader<TcpStream>,
271
    tcp: &mut TcpStream,
272
) -> anyhow::Result<()> {
273
    let mut game = Game::default();
274

            
275
    debug!("{game}\n");
276

            
277
    let mut buf = String::new();
278
    loop {
279
        reader.read_line(&mut buf)?;
280

            
281
        if buf.trim().is_empty() {
282
            return Err(Error::msg("the TCP stream has closed"));
283
        }
284

            
285
        let mut message: Vec<_> = buf.split_ascii_whitespace().collect();
286

            
287
        if Some("generate_move") == message.get(2).copied() {
288
            let generate_move = ai.generate_move(&mut game)?;
289

            
290
            tcp.write_all(format!("game {game_id} {}\n", generate_move.play).as_bytes())?;
291

            
292
            debug!("{game}");
293
            info!("{generate_move}");
294
            trace!("{}", generate_move.heat_map);
295

            
296
            if game.status != Status::Ongoing {
297
                return Ok(());
298
            }
299
        } else if Some("play") == message.get(2).copied() {
300
            let words = message.split_off(2);
301
            let play = Plae::try_from(words)?;
302
            ai.play(&mut game, &play)?;
303

            
304
            debug!("{game}\n");
305

            
306
            if game.status != Status::Ongoing {
307
                return Ok(());
308
            }
309
        } else if Some("game_over") == message.get(1).copied() {
310
            return Ok(());
311
        }
312

            
313
        buf.clear();
314
    }
315
}
316

            
317
fn systemd_delay_restart(args: &Args) -> anyhow::Result<()> {
318
    if args.systemd {
319
        let service = match args.role {
320
            Role::Attacker => "hnefatafl-ai-attacker.service",
321
            Role::Defender => "hnefatafl-ai-defender.service",
322
            Role::Roleless => unreachable!(),
323
        };
324

            
325
        let output = Command::new("systemctl")
326
            .args(["show", service, "-p", "NRestarts"])
327
            .output()?;
328

            
329
        let i = String::from_utf8_lossy(&output.stdout)
330
            .replace("NRestarts=", "")
331
            .trim()
332
            .parse()?;
333

            
334
        if i > 0 {
335
            let delay = 2u64.pow(i);
336
            log::info!("sleeping for {delay}s...");
337
            sleep(Duration::from_secs(delay));
338
        }
339
    }
340

            
341
    Ok(())
342
}