1
// This file is part of hnefatafl-copenhagen.
2
//
3
// hnefatafl-copenhagen is free software: you can redistribute it and/or modify
4
// it under the terms of the GNU Affero General Public License as published by
5
// the Free Software Foundation, either version 3 of the License, or
6
// (at your option) any later version.
7
//
8
// hnefatafl-copenhagen is distributed in the hope that it will be useful,
9
// but WITHOUT ANY WARRANTY; without even the implied warranty of
10
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11
// GNU Affero General Public License for more details.
12
//
13
// You should have received a copy of the GNU Affero General Public License
14
// along with this program.  If not, see <https://www.gnu.org/licenses/>.
15

            
16
#![deny(clippy::expect_used)]
17
#![deny(clippy::indexing_slicing)]
18
#![deny(clippy::panic)]
19
#![deny(clippy::unwrap_used)]
20

            
21
use std::{
22
    io::{BufRead, BufReader, Write},
23
    net::{TcpStream, ToSocketAddrs as _},
24
    time::Duration,
25
};
26

            
27
use anyhow::Error;
28
use clap::{CommandFactory, Parser};
29
use hnefatafl_copenhagen::{
30
    COPYRIGHT, LONG_VERSION, VERSION_ID,
31
    ai::AI,
32
    game::Game,
33
    play::Plae,
34
    role::Role,
35
    status::Status,
36
    utils::{self, choose_ai},
37
};
38
use log::{debug, info, trace};
39
use socket2::{Domain, SockAddr, Socket, TcpKeepalive, Type};
40

            
41
// Move 26, defender wins, corner escape, time per move 15s 2025-03-06 (hnefatafl-equi).
42

            
43
/// Port of the HTP server, including the leading `:`.
///
/// `main` appends it to the `--host` value to build the address string that
/// is resolved and connected to.
const PORT: &str = ":49152";
44

            
45
/// Copenhagen Hnefatafl AI
///
/// This is an AI client that connects to a server.
// NOTE: clap's derive turns the `///` doc comments in this struct into the
// `--help` text, and `main` also renders them into the manpage via
// `Args::command()`. They are therefore user-facing. Plain `//` comments like
// this one are ignored by clap.
//
// The bool fields below are independent on/off command-line flags, which is
// why clippy's `struct_excessive_bools` lint is allowed here.
#[allow(clippy::struct_excessive_bools)]
#[derive(Parser, Debug)]
#[command(long_version = LONG_VERSION, about = "Copenhagen Hnefatafl AI")]
struct Args {
    // `main` prefixes this with `ai-` before sending it in the login line.
    /// Enter your username
    #[arg(long)]
    username: String,

    // Sent as the last word of the login line; empty unless given.
    /// Enter your password
    #[arg(default_value = "", long)]
    password: String,

    // Role requested when `main` creates new games (unused with `--join-game`).
    /// Set the role as attacker or defender
    #[arg(default_value_t = Role::Attacker, long)]
    role: Role,

    // `main` connects to this host on the fixed port in `PORT`.
    /// Connect to the HTP server at host
    #[arg(default_value = "hnefatafl.org", long)]
    host: String,

    // AI name passed to `choose_ai`.
    /// Choose an AI to play as
    #[arg(default_value = "basic", long)]
    ai: String,

    // Passed to `utils::init_logger`.
    /// Whether to log on the debug level
    #[arg(long)]
    debug: bool,

    // Passed to `choose_ai`.
    /// How many seconds to run the monte-carlo AI
    #[arg(long)]
    seconds: Option<u64>,

    // Passed to `choose_ai`; the per-AI defaults listed in the help text are
    // presumably applied there (not visible in this file).
    /// How deep in the game tree to go with Ai
    ///
    /// [default basic: 4]
    /// [default monte-carlo: 20]
    #[arg(long)]
    depth: Option<u8>,

    // When given, `main` joins this pending game once instead of creating
    // new games in a loop.
    /// Join game with id
    #[arg(long)]
    join_game: Option<u128>,

    // Passed to `choose_ai`.
    /// Run the basic AI sequentially
    #[arg(long)]
    sequential: bool,

    // Passed to `utils::init_logger`.
    /// Whether the application is being run by systemd
    #[arg(long)]
    systemd: bool,

    // When set, `main` writes the manpage to `hnefatafl-ai.1` and exits
    // without connecting to the server.
    /// Build the manpage
    #[arg(long)]
    man: bool,
}
103

            
104
fn main() -> anyhow::Result<()> {
105
    let args = Args::parse();
106
    utils::init_logger("hnefatafl_ai", args.debug, args.systemd);
107

            
108
    if args.man {
109
        let mut buffer: Vec<u8> = Vec::default();
110
        let cmd = Args::command().name("hnefatafl-ai").long_version(None);
111
        let man = clap_mangen::Man::new(cmd).date("2025-06-23");
112

            
113
        man.render(&mut buffer)?;
114
        write!(buffer, "{COPYRIGHT}")?;
115

            
116
        std::fs::write("hnefatafl-ai.1", buffer)?;
117
        return Ok(());
118
    }
119

            
120
    let mut username = "ai-".to_string();
121
    username.push_str(&args.username);
122

            
123
    let mut address_string = args.host.clone();
124
    address_string.push_str(PORT);
125

            
126
    let mut is_ipv6 = false;
127
    let mut socket_address = None;
128
    let socket_addresses = address_string.to_socket_addrs()?;
129

            
130
    for address in socket_addresses.clone() {
131
        if address.is_ipv6() {
132
            socket_address = Some(address);
133
            is_ipv6 = true;
134
            break;
135
        }
136
    }
137

            
138
    if !is_ipv6 {
139
        for address in socket_addresses {
140
            if address.is_ipv4() {
141
                socket_address = Some(address);
142
                break;
143
            }
144
        }
145
    }
146

            
147
    let socket_address = socket_address.ok_or_else(|| {
148
        anyhow::Error::msg(format!(
149
            "There is no IP address for the host: {address_string}"
150
        ))
151
    })?;
152

            
153
    let address: SockAddr = socket_address.into();
154
    let keepalive = TcpKeepalive::new()
155
        .with_time(Duration::from_secs(30))
156
        .with_interval(Duration::from_secs(30))
157
        .with_retries(3);
158

            
159
    let domain_type = if is_ipv6 { Domain::IPV6 } else { Domain::IPV4 };
160
    let socket = Socket::new(domain_type, Type::STREAM, None)?;
161
    socket.set_tcp_keepalive(&keepalive)?;
162

            
163
    socket.connect(&address).unwrap_or_else(|error| {
164
        eprintln!("socket.connect {address_string}: {error}");
165
    });
166

            
167
    info!("connected to {socket_address}");
168

            
169
    let mut tcp: TcpStream = socket.into();
170
    let mut reader = BufReader::new(tcp.try_clone()?);
171

            
172
    tcp.write_all(format!("{VERSION_ID} login {username} {}\n", args.password).as_bytes())?;
173

            
174
    let mut buf = String::new();
175
    reader.read_line(&mut buf)?;
176
    assert_eq!(buf, "= login\n");
177
    buf.clear();
178

            
179
    if let Some(game_id) = args.join_game {
180
        tcp.write_all(format!("join_game_pending {game_id}\n").as_bytes())?;
181

            
182
        let ai = choose_ai(&args.ai, args.seconds, args.depth, args.sequential)?;
183
        handle_messages(ai, game_id, &mut reader, &mut tcp)?;
184
    } else {
185
        loop {
186
            new_game(&mut tcp, args.role, &mut reader, &mut buf)?;
187

            
188
            info!("{buf}");
189

            
190
            let message: Vec<_> = buf.split_ascii_whitespace().collect();
191
            let Some(message) = message.get(3) else {
192
                return Err(anyhow::Error::msg("Expecting message[3] to be a game_id"));
193
            };
194

            
195
            let game_id = message.parse()?;
196
            buf.clear();
197

            
198
            wait_for_challenger(&mut reader, &mut buf, &mut tcp, game_id)?;
199

            
200
            let ai = choose_ai(&args.ai, args.seconds, args.depth, args.sequential)?;
201
            handle_messages(ai, game_id, &mut reader, &mut tcp)?;
202
        }
203
    }
204

            
205
    Ok(())
206
}
207

            
208
fn new_game(
209
    tcp: &mut TcpStream,
210
    role: Role,
211
    reader: &mut BufReader<TcpStream>,
212
    buf: &mut String,
213
) -> anyhow::Result<()> {
214
    tcp.write_all(format!("new_game {role} rated fischer 900000 10 11\n").as_bytes())?;
215

            
216
    loop {
217
        reader.read_line(buf)?;
218

            
219
        if buf.trim().is_empty() {
220
            return Err(Error::msg("the TCP stream has closed"));
221
        }
222

            
223
        let message: Vec<_> = buf.split_ascii_whitespace().collect();
224
        if let Some(message) = message.get(1)
225
            && *message == "new_game"
226
        {
227
            return Ok(());
228
        }
229

            
230
        buf.clear();
231
    }
232
}
233

            
234
fn wait_for_challenger(
235
    reader: &mut BufReader<TcpStream>,
236
    buf: &mut String,
237
    tcp: &mut TcpStream,
238
    game_id: u128,
239
) -> anyhow::Result<()> {
240
    loop {
241
        reader.read_line(buf)?;
242

            
243
        if buf.trim().is_empty() {
244
            return Err(Error::msg("the TCP stream has closed"));
245
        }
246

            
247
        let message: Vec<_> = buf.split_ascii_whitespace().collect();
248
        if Some("challenge_requested") == message.get(1).copied() {
249
            info!("{message:?}");
250
            buf.clear();
251

            
252
            break;
253
        }
254

            
255
        buf.clear();
256
    }
257

            
258
    tcp.write_all(format!("join_game {game_id}\n").as_bytes())?;
259
    Ok(())
260
}
261

            
262
fn handle_messages(
263
    mut ai: Box<dyn AI>,
264
    game_id: u128,
265
    reader: &mut BufReader<TcpStream>,
266
    tcp: &mut TcpStream,
267
) -> anyhow::Result<()> {
268
    let mut game = Game::default();
269

            
270
    debug!("{game}\n");
271

            
272
    let mut buf = String::new();
273
    loop {
274
        reader.read_line(&mut buf)?;
275

            
276
        if buf.trim().is_empty() {
277
            return Err(Error::msg("the TCP stream has closed"));
278
        }
279

            
280
        let mut message: Vec<_> = buf.split_ascii_whitespace().collect();
281

            
282
        if Some("generate_move") == message.get(2).copied() {
283
            let generate_move = ai.generate_move(&mut game)?;
284

            
285
            tcp.write_all(format!("game {game_id} {}\n", generate_move.play).as_bytes())?;
286

            
287
            debug!("{game}");
288
            info!("{generate_move}");
289
            trace!("{}", generate_move.heat_map);
290

            
291
            if game.status != Status::Ongoing {
292
                return Ok(());
293
            }
294
        } else if Some("play") == message.get(2).copied() {
295
            let words = message.split_off(2);
296
            let play = Plae::try_from(words)?;
297
            ai.play(&mut game, &play)?;
298

            
299
            debug!("{game}\n");
300

            
301
            if game.status != Status::Ongoing {
302
                return Ok(());
303
            }
304
        } else if Some("game_over") == message.get(1).copied() {
305
            return Ok(());
306
        }
307

            
308
        buf.clear();
309
    }
310
}