1
use std::io::Read;
2
use std::io::Write;
3
use std::io::{self};
4

            
5
use bitstream_io::BigEndian;
6
use bitstream_io::BitRead;
7
use bitstream_io::BitReader;
8
use bitstream_io::BitWrite;
9
use bitstream_io::BitWriter;
10
use log::error;
11

            
12
use merc_number::read_u64_variablelength;
13
use merc_number::write_u64_variablelength;
14
use merc_utilities::MercError;
15

            
16
/// Trait for writing bit-level data.
17
pub trait BitStreamWrite {
18
    /// Writes the least significant bits from a u64 value.
19
    ///
20
    /// # Preconditions
21
    /// - number_of_bits must be <= 64
22
    fn write_bits(&mut self, value: u64, number_of_bits: u8) -> Result<(), MercError>;
23

            
24
    /// Writes a string prefixed with its length as a variable-width integer.
25
    fn write_string(&mut self, s: &str) -> Result<(), MercError>;
26

            
27
    /// Writes a u64 value using variable-width encoding.
28
    fn write_integer(&mut self, value: u64) -> Result<(), MercError>;
29

            
30
    /// Flushes any remaining bits to the underlying writer.
31
    fn flush(&mut self) -> Result<(), MercError>;
32
}
33

            
34
/// Trait for reading bit-level data.
35
pub trait BitStreamRead {
36
    /// Reads bits into the least significant bits of a u64.
37
    ///
38
    /// # Preconditions
39
    /// - number_of_bits must be <= 64
40
    fn read_bits(&mut self, number_of_bits: u8) -> Result<u64, MercError>;
41

            
42
    /// Reads a length-prefixed string.
43
    fn read_string(&mut self) -> Result<String, MercError>;
44

            
45
    /// Reads a variable-width encoded integer.
46
    fn read_integer(&mut self) -> Result<u64, MercError>;
47
}
48

            
49
/// Writer for bit-level output operations using an underlying writer.
50
pub struct BitStreamWriter<W: Write> {
51
    writer: BitWriter<W, BigEndian>,
52
}
53

            
54
impl<W: Write> BitStreamWriter<W> {
55
    /// Creates a new BitStreamWriter wrapping the provided writer.
56
500
    pub fn new(writer: W) -> Self {
57
500
        Self {
58
500
            writer: BitWriter::new(writer),
59
500
        }
60
500
    }
61
}
62

            
63
impl<W: Write> Drop for BitStreamWriter<W> {
64
500
    fn drop(&mut self) {
65
500
        if self.flush().is_err() {
66
            error!("Panicked while flushing the stream when dropped!")
67
500
        }
68
500
    }
69
}
70

            
71
/// Reader for bit-level input operations from an underlying reader.
72
pub struct BitStreamReader<R: Read> {
73
    reader: BitReader<R, BigEndian>,
74
    text_buffer: Vec<u8>,
75
}
76

            
77
impl<R: Read> BitStreamReader<R> {
78
    /// Creates a new BitStreamReader wrapping the provided reader.
79
502
    pub fn new(reader: R) -> Self {
80
502
        Self {
81
502
            reader: BitReader::new(reader),
82
502
            text_buffer: Vec::with_capacity(128),
83
502
        }
84
502
    }
85
}
86

            
87
impl<W: Write> BitStreamWrite for BitStreamWriter<W> {
88
2121803
    fn write_bits(&mut self, value: u64, number_of_bits: u8) -> Result<(), MercError> {
89
2121803
        debug_assert!(number_of_bits <= 64);
90
2121803
        Ok(self.writer.write_var(number_of_bits as u32, value)?)
91
2121803
    }
92

            
93
7225
    fn write_string(&mut self, s: &str) -> Result<(), MercError> {
94
7225
        self.write_integer(s.len() as u64)?;
95
50375
        for byte in s.as_bytes() {
96
50375
            self.writer.write::<8, u64>(*byte as u64)?;
97
        }
98
7225
        Ok(())
99
7225
    }
100

            
101
650802
    fn write_integer(&mut self, value: u64) -> Result<(), MercError> {
102
650802
        write_u64_variablelength(&mut self.writer, value)?;
103
650802
        Ok(())
104
650802
    }
105

            
106
900
    fn flush(&mut self) -> Result<(), MercError> {
107
900
        self.writer.byte_align()?;
108
900
        Ok(self.writer.flush()?)
109
900
    }
110
}
111

            
112
impl<R: Read> BitStreamRead for BitStreamReader<R> {
113
2161985
    fn read_bits(&mut self, number_of_bits: u8) -> Result<u64, MercError> {
114
2161985
        assert!(number_of_bits <= 64);
115
2161985
        Ok(self.reader.read_var(number_of_bits as u32)?)
116
2161985
    }
117

            
118
7445
    fn read_string(&mut self) -> Result<String, MercError> {
119
7445
        let length = self.read_integer()?;
120
7445
        self.text_buffer.clear();
121
7445
        self.text_buffer
122
7445
            .reserve(length.try_into().expect("String size exceeds usize!"));
123

            
124
7445
        for _ in 0..length {
125
53378
            let byte = self.reader.read::<8, u8>()?;
126
53378
            self.text_buffer.push(byte);
127
        }
128

            
129
7445
        Ok(String::from_utf8(self.text_buffer.clone()).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?)
130
7445
    }
131

            
132
661364
    fn read_integer(&mut self) -> Result<u64, MercError> {
133
661364
        read_u64_variablelength(&mut self.reader)
134
661364
    }
135
}
136

            
137
#[cfg(test)]
mod tests {
    use log::debug;
    use merc_utilities::random_test;
    use rand::Rng;
    use rand::distr::Alphanumeric;

    use super::*;

    /// A single operation to perform on the bit stream, chosen at random by the test.
    #[derive(Debug)]
    enum Instruction {
        String(String),
        Integer(u64),
        /// (value, num_of_bits), where num_of_bits must be at most 64.
        Bits(u64, u8),
    }

    /// Returns the minimum number of bits needed to represent `value`.
    ///
    /// Zero still gets one bit, so every `Bits` instruction writes something.
    pub fn required_bits(value: u64) -> u8 {
        if value == 0 {
            1
        } else {
            64 - value.leading_zeros() as u8
        }
    }

    #[test]
    fn test_arbitrary_bitstream() {
        random_test(100, |rng| {
            let instructions: Vec<Instruction> = (0..100)
                // The upper bound is exclusive: `0..3` is required for `Bits` to ever be chosen.
                .map(|_| match rng.random_range(0..3) {
                    0 => {
                        // Vary the length, including empty strings, to exercise the length prefix.
                        let length = rng.random_range(0..16);
                        let string = rng
                            .sample_iter(&Alphanumeric)
                            .take(length)
                            .map(char::from)
                            .collect();
                        Instruction::String(string)
                    }
                    1 => Instruction::Integer(rng.random()),
                    2 => {
                        let value: u64 = rng.random();
                        Instruction::Bits(value, required_bits(value))
                    }
                    _ => unreachable!("The range is from 0 to 2"),
                })
                .collect();

            let mut buffer = Vec::new();
            {
                let mut writer = BitStreamWriter::new(&mut buffer);

                for inst in &instructions {
                    debug!("Writing {inst:?}");
                    match inst {
                        Instruction::String(string) => {
                            writer.write_string(string).expect("Failed to write into stream")
                        }
                        Instruction::Integer(value) => {
                            writer.write_integer(*value).expect("Failed to write into stream")
                        }
                        Instruction::Bits(value, number_of_bits) => writer
                            .write_bits(*value, *number_of_bits)
                            .expect("Failed to write into stream"),
                    }
                }

                writer.flush().expect("Failed to write into stream");
            }

            let mut reader = BitStreamReader::new(&buffer[..]);

            // Use `assert_eq!` rather than `debug_assert_eq!` so the checks also run in
            // release-mode test builds.
            for inst in &instructions {
                debug!("Checking {inst:?}");
                match inst {
                    Instruction::String(string) => assert_eq!(
                        reader.read_string().expect("Failed to read from stream"),
                        *string,
                        "Failed to read back the string"
                    ),
                    Instruction::Integer(value) => assert_eq!(
                        reader.read_integer().expect("Failed to read from stream"),
                        *value,
                        "Failed to read back the integer"
                    ),
                    Instruction::Bits(value, number_of_bits) => assert_eq!(
                        reader.read_bits(*number_of_bits).expect("Failed to read from stream"),
                        *value,
                        "Failed to read back the bits"
                    ),
                }
            }
        });
    }
}