1
use std::io::Read;
2
use std::io::Write;
3
use std::io::{self};
4

            
5
use bitstream_io::BigEndian;
6
use bitstream_io::BitRead;
7
use bitstream_io::BitReader;
8
use bitstream_io::BitWrite;
9
use bitstream_io::BitWriter;
10
use log::error;
11

            
12
use merc_number::read_u64_variablelength;
13
use merc_number::write_u64_variablelength;
14
use merc_utilities::MercError;
15

            
16
/// Trait for writing bit-level data.
17
pub trait BitStreamWrite {
18
    /// Writes the least significant bits from a u64 value. The `number_of_bits` must be <= 64.
19
    fn write_bits(&mut self, value: u64, number_of_bits: u8) -> Result<(), MercError>;
20

            
21
    /// Writes a string prefixed with its length as a variable-width integer.
22
    fn write_string(&mut self, s: &str) -> Result<(), MercError>;
23

            
24
    /// Writes a u64 value using variable-width encoding.
25
    fn write_integer(&mut self, value: u64) -> Result<(), MercError>;
26

            
27
    /// Flushes any remaining bits to the underlying writer.
28
    fn flush(&mut self) -> Result<(), MercError>;
29
}
30

            
31
/// Trait for reading bit-level data.
32
pub trait BitStreamRead {
33
    /// Reads bits into the least significant bits of a u64. The `number_of_bits` must be <= 64.
34
    fn read_bits(&mut self, number_of_bits: u8) -> Result<u64, MercError>;
35

            
36
    /// Reads a length-prefixed string.
37
    fn read_string(&mut self) -> Result<String, MercError>;
38

            
39
    /// Reads a variable-width encoded integer.
40
    fn read_integer(&mut self) -> Result<u64, MercError>;
41
}
42

            
43
/// Writer for bit-level output operations using an underlying writer.
44
pub struct BitStreamWriter<W: Write> {
45
    writer: BitWriter<W, BigEndian>,
46
}
47

            
48
impl<W: Write> BitStreamWriter<W> {
49
    /// Creates a new BitStreamWriter wrapping the provided writer.
50
500
    pub fn new(writer: W) -> Self {
51
500
        Self {
52
500
            writer: BitWriter::new(writer),
53
500
        }
54
500
    }
55
}
56

            
57
impl<W: Write> Drop for BitStreamWriter<W> {
58
500
    fn drop(&mut self) {
59
500
        if self.flush().is_err() {
60
            error!("Panicked while flushing the stream when dropped!")
61
500
        }
62
500
    }
63
}
64

            
65
/// Reader for bit-level input operations from an underlying reader.
66
pub struct BitStreamReader<R: Read> {
67
    reader: BitReader<R, BigEndian>,
68
    text_buffer: Vec<u8>,
69
}
70

            
71
impl<R: Read> BitStreamReader<R> {
72
    /// Creates a new BitStreamReader wrapping the provided reader.
73
504
    pub fn new(reader: R) -> Self {
74
504
        Self {
75
504
            reader: BitReader::new(reader),
76
504
            text_buffer: Vec::with_capacity(128),
77
504
        }
78
504
    }
79
}
80

            
81
impl<W: Write> BitStreamWrite for BitStreamWriter<W> {
82
2119386
    fn write_bits(&mut self, value: u64, number_of_bits: u8) -> Result<(), MercError> {
83
2119386
        debug_assert!(number_of_bits <= 64);
84
2119386
        Ok(self.writer.write_var(number_of_bits as u32, value)?)
85
2119386
    }
86

            
87
7175
    fn write_string(&mut self, s: &str) -> Result<(), MercError> {
88
7175
        self.write_integer(s.len() as u64)?;
89
50025
        for byte in s.as_bytes() {
90
50025
            self.writer.write::<8, u64>(*byte as u64)?;
91
        }
92
7175
        Ok(())
93
7175
    }
94

            
95
650298
    fn write_integer(&mut self, value: u64) -> Result<(), MercError> {
96
650298
        write_u64_variablelength(&mut self.writer, value)?;
97
650298
        Ok(())
98
650298
    }
99

            
100
900
    fn flush(&mut self) -> Result<(), MercError> {
101
900
        self.writer.byte_align()?;
102
900
        Ok(self.writer.flush()?)
103
900
    }
104
}
105

            
106
impl<R: Read> BitStreamRead for BitStreamReader<R> {
107
2163128
    fn read_bits(&mut self, number_of_bits: u8) -> Result<u64, MercError> {
108
2163128
        assert!(number_of_bits <= 64);
109
2163128
        Ok(self.reader.read_var(number_of_bits as u32)?)
110
2163128
    }
111

            
112
7481
    fn read_string(&mut self) -> Result<String, MercError> {
113
7481
        let length = self.read_integer()?;
114
7481
        self.text_buffer.clear();
115
7481
        self.text_buffer
116
7481
            .reserve(length.try_into().expect("String size exceeds usize!"));
117

            
118
7481
        for _ in 0..length {
119
53560
            let byte = self.reader.read::<8, u8>()?;
120
53560
            self.text_buffer.push(byte);
121
        }
122

            
123
7481
        Ok(String::from_utf8(self.text_buffer.clone()).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?)
124
7481
    }
125

            
126
661646
    fn read_integer(&mut self) -> Result<u64, MercError> {
127
661646
        read_u64_variablelength(&mut self.reader)
128
661646
    }
129
}
130

            
131
#[cfg(test)]
mod tests {
    use super::*;

    use log::debug;
    use rand::RngExt;
    use rand::distr::Alphanumeric;

    use merc_utilities::random_test;

    /// Decide (arbitrarily) what to write into the bitstream.
    #[derive(Debug)]
    enum Instruction {
        String(String),
        Integer(u64),
        /// (value, num_of_bits), where num_of_bits must be at most 64.
        Bits(u64, u8),
    }

    /// Calculate the minimum number of bits needed to represent the value.
    /// Use 1 bit if the value is 0, to ensure at least 1 bit is written.
    pub fn required_bits(value: u64) -> u8 {
        if value == 0 {
            1
        } else {
            64 - value.leading_zeros() as u8
        }
    }

    /// Writes a random mix of strings, integers and raw bit fields, then checks that reading
    /// the stream back yields the same values in the same order.
    #[test]
    fn test_arbitrary_bitstream() {
        random_test(100, |rng| {
            let instructions: Vec<Instruction> = (0..100)
                // `random_range` is half-open: `0..3` is required to ever reach the `Bits` arm.
                .map(|_| match rng.random_range(0..3) {
                    0 => {
                        let string = rng.sample_iter(&Alphanumeric).take(7).map(char::from).collect();
                        Instruction::String(string)
                    }
                    1 => Instruction::Integer(rng.random()),
                    2 => {
                        let value: u64 = rng.random();
                        Instruction::Bits(value, required_bits(value))
                    }
                    _ => unreachable!("The range is from 0 to 2"),
                })
                .collect();

            let mut buffer = Vec::new();
            {
                let mut writer = BitStreamWriter::new(&mut buffer);

                for inst in &instructions {
                    debug!("Writing {inst:?}");
                    match inst {
                        Instruction::String(string) => {
                            writer.write_string(string).expect("Failed to write into stream")
                        }
                        Instruction::Integer(value) => {
                            writer.write_integer(*value).expect("Failed to write into stream")
                        }
                        Instruction::Bits(value, number_of_bits) => writer
                            .write_bits(*value, *number_of_bits)
                            .expect("Failed to write into stream"),
                    }
                }

                writer.flush().expect("Failed to write into stream");
            }

            let mut reader = BitStreamReader::new(&buffer[..]);

            // Use `assert_eq!` (not `debug_assert_eq!`) so the checks also run under
            // `cargo test --release`.
            for inst in &instructions {
                debug!("Checking {inst:?}");
                match inst {
                    Instruction::String(string) => {
                        assert_eq!(
                            reader.read_string().expect("Failed to read from stream"),
                            *string,
                            "Failed to read back the string"
                        )
                    }
                    Instruction::Integer(value) => {
                        assert_eq!(
                            reader.read_integer().expect("Failed to read from stream"),
                            *value,
                            "Failed to read back the integer"
                        )
                    }
                    Instruction::Bits(value, number_of_bits) => {
                        assert_eq!(
                            reader.read_bits(*number_of_bits).expect("Failed to read from stream"),
                            *value,
                            "Failed to read back the bits"
                        )
                    }
                }
            }
        });
    }
}