sheetkit_core/
vba.rs

1//! VBA project extraction from macro-enabled workbooks (.xlsm).
2//!
3//! `.xlsm` files contain a `xl/vbaProject.bin` entry which is an OLE2
4//! Compound Binary File (CFB) holding VBA source code. This module
5//! provides read-only access to the raw binary and to individual VBA
6//! module source code.
7
8use std::io::{Cursor, Read as _};
9
10use crate::error::{Error, Result};
11
/// The kind of module a piece of VBA source code belongs to.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum VbaModuleType {
    /// Standard code module (`.bas`).
    Standard,
    /// Class module (`.cls`).
    Class,
    /// UserForm module.
    Form,
    /// Document module, e.g. the code-behind of a sheet.
    Document,
    /// The `ThisWorkbook` module.
    ThisWorkbook,
}
26
27/// A single VBA module with its name, source code, and type.
28#[derive(Debug, Clone)]
29pub struct VbaModule {
30    pub name: String,
31    pub source_code: String,
32    pub module_type: VbaModuleType,
33}
34
35/// Result of extracting a VBA project from a `.xlsm` file.
36///
37/// Contains extracted modules and any non-fatal warnings encountered
38/// during parsing (e.g., unreadable streams, decompression failures,
39/// unsupported codepages).
40#[derive(Debug, Clone)]
41pub struct VbaProject {
42    pub modules: Vec<VbaModule>,
43    pub warnings: Vec<String>,
44}
45
46/// Offset entry parsed from the `dir` stream for a single module.
47struct ModuleEntry {
48    name: String,
49    stream_name: String,
50    text_offset: u32,
51    module_type: VbaModuleType,
52}
53
54/// Parsed metadata from the `dir` stream.
55struct DirInfo {
56    entries: Vec<ModuleEntry>,
57    codepage: u16,
58}
59
60/// Extract VBA module source code from a `vbaProject.bin` binary blob.
61///
62/// Parses the OLE/CFB container, reads the `dir` stream to discover
63/// module metadata, then decompresses each module stream.
64///
65/// Returns a [`VbaProject`] containing extracted modules and any
66/// non-fatal warnings (e.g., modules that could not be read or
67/// decompressed, unsupported codepages).
68pub fn extract_vba_modules(vba_bin: &[u8]) -> Result<VbaProject> {
69    let cursor = Cursor::new(vba_bin);
70    let mut cfb = cfb::CompoundFile::open(cursor)
71        .map_err(|e| Error::Internal(format!("failed to open VBA project as CFB: {e}")))?;
72
73    // Find the VBA storage root. Typically `/VBA` or could be nested.
74    let vba_prefix = find_vba_prefix(&mut cfb)?;
75
76    // Read the `dir` stream to get module entries.
77    let dir_path = format!("{vba_prefix}dir");
78    let dir_data = read_cfb_stream(&mut cfb, &dir_path)?;
79
80    // The dir stream is compressed using MS-OVBA compression.
81    let decompressed_dir = decompress_vba_stream(&dir_data)?;
82
83    // Parse module entries and codepage from the decompressed dir stream.
84    let dir_info = parse_dir_stream(&decompressed_dir)?;
85
86    let mut modules = Vec::with_capacity(dir_info.entries.len());
87    let mut warnings = Vec::new();
88
89    for entry in dir_info.entries {
90        let stream_path = format!("{vba_prefix}{}", entry.stream_name);
91        let compressed_data = match read_cfb_stream(&mut cfb, &stream_path) {
92            Ok(data) => data,
93            Err(e) => {
94                warnings.push(format!(
95                    "skipped module '{}': failed to read stream '{}': {}",
96                    entry.name, stream_path, e
97                ));
98                continue;
99            }
100        };
101
102        // The module stream has text_offset bytes of "performance cache"
103        // (compiled code) followed by compressed source code.
104        if (entry.text_offset as usize) > compressed_data.len() {
105            warnings.push(format!(
106                "skipped module '{}': text_offset {} exceeds stream length {}",
107                entry.name,
108                entry.text_offset,
109                compressed_data.len()
110            ));
111            continue;
112        }
113        let source_compressed = &compressed_data[entry.text_offset as usize..];
114        let source_bytes = match decompress_vba_stream(source_compressed) {
115            Ok(b) => b,
116            Err(e) => {
117                warnings.push(format!(
118                    "skipped module '{}': decompression failed: {}",
119                    entry.name, e
120                ));
121                continue;
122            }
123        };
124
125        let source_code = decode_source_bytes(&source_bytes, dir_info.codepage, &mut warnings);
126
127        modules.push(VbaModule {
128            name: entry.name,
129            source_code,
130            module_type: entry.module_type,
131        });
132    }
133
134    Ok(VbaProject { modules, warnings })
135}
136
137/// Decode source bytes using the specified codepage.
138///
139/// Supports common codepages: 1252 (Western European), 932 (Japanese Shift-JIS),
140/// 949 (Korean), 936 (Simplified Chinese GBK), 65001 (UTF-8).
141/// For unrecognized codepages, falls back to UTF-8 lossy and emits a warning.
142fn decode_source_bytes(bytes: &[u8], codepage: u16, warnings: &mut Vec<String>) -> String {
143    match codepage {
144        65001 | 0 => String::from_utf8_lossy(bytes).into_owned(),
145        1252 => decode_single_byte(bytes, &WINDOWS_1252_HIGH),
146        932 => decode_shift_jis(bytes),
147        949 => decode_euc_kr(bytes),
148        936 => decode_gbk(bytes),
149        _ => {
150            warnings.push(format!(
151                "unsupported codepage {codepage}, falling back to UTF-8 lossy"
152            ));
153            String::from_utf8_lossy(bytes).into_owned()
154        }
155    }
156}
157
/// Unicode scalar for each Windows-1252 byte in `0x80..=0xFF`.
///
/// Entry `i` is the character for byte `0x80 + i`; bytes below `0x80` are
/// plain ASCII and are not part of the table. The five bytes Windows-1252
/// leaves undefined (0x81, 0x8D, 0x8F, 0x90, 0x9D) map to the C1 control
/// character with the same value.
static WINDOWS_1252_HIGH: [char; 128] = [
    // 0x80..=0x87
    '\u{20AC}', '\u{0081}', '\u{201A}', '\u{0192}', '\u{201E}', '\u{2026}', '\u{2020}', '\u{2021}',
    // 0x88..=0x8F
    '\u{02C6}', '\u{2030}', '\u{0160}', '\u{2039}', '\u{0152}', '\u{008D}', '\u{017D}', '\u{008F}',
    // 0x90..=0x97
    '\u{0090}', '\u{2018}', '\u{2019}', '\u{201C}', '\u{201D}', '\u{2022}', '\u{2013}', '\u{2014}',
    // 0x98..=0x9F
    '\u{02DC}', '\u{2122}', '\u{0161}', '\u{203A}', '\u{0153}', '\u{009D}', '\u{017E}', '\u{0178}',
    // 0xA0..=0xA7 (from here on the table is the identity mapping)
    '\u{00A0}', '\u{00A1}', '\u{00A2}', '\u{00A3}', '\u{00A4}', '\u{00A5}', '\u{00A6}', '\u{00A7}',
    // 0xA8..=0xAF
    '\u{00A8}', '\u{00A9}', '\u{00AA}', '\u{00AB}', '\u{00AC}', '\u{00AD}', '\u{00AE}', '\u{00AF}',
    // 0xB0..=0xB7
    '\u{00B0}', '\u{00B1}', '\u{00B2}', '\u{00B3}', '\u{00B4}', '\u{00B5}', '\u{00B6}', '\u{00B7}',
    // 0xB8..=0xBF
    '\u{00B8}', '\u{00B9}', '\u{00BA}', '\u{00BB}', '\u{00BC}', '\u{00BD}', '\u{00BE}', '\u{00BF}',
    // 0xC0..=0xC7
    '\u{00C0}', '\u{00C1}', '\u{00C2}', '\u{00C3}', '\u{00C4}', '\u{00C5}', '\u{00C6}', '\u{00C7}',
    // 0xC8..=0xCF
    '\u{00C8}', '\u{00C9}', '\u{00CA}', '\u{00CB}', '\u{00CC}', '\u{00CD}', '\u{00CE}', '\u{00CF}',
    // 0xD0..=0xD7
    '\u{00D0}', '\u{00D1}', '\u{00D2}', '\u{00D3}', '\u{00D4}', '\u{00D5}', '\u{00D6}', '\u{00D7}',
    // 0xD8..=0xDF
    '\u{00D8}', '\u{00D9}', '\u{00DA}', '\u{00DB}', '\u{00DC}', '\u{00DD}', '\u{00DE}', '\u{00DF}',
    // 0xE0..=0xE7
    '\u{00E0}', '\u{00E1}', '\u{00E2}', '\u{00E3}', '\u{00E4}', '\u{00E5}', '\u{00E6}', '\u{00E7}',
    // 0xE8..=0xEF
    '\u{00E8}', '\u{00E9}', '\u{00EA}', '\u{00EB}', '\u{00EC}', '\u{00ED}', '\u{00EE}', '\u{00EF}',
    // 0xF0..=0xF7
    '\u{00F0}', '\u{00F1}', '\u{00F2}', '\u{00F3}', '\u{00F4}', '\u{00F5}', '\u{00F6}', '\u{00F7}',
    // 0xF8..=0xFF
    '\u{00F8}', '\u{00F9}', '\u{00FA}', '\u{00FB}', '\u{00FC}', '\u{00FD}', '\u{00FE}', '\u{00FF}',
];
178
/// Decode text stored in a single-byte codepage.
///
/// Bytes below `0x80` are taken as ASCII; every other byte `b` is looked up
/// as `high_table[b - 0x80]`.
fn decode_single_byte(bytes: &[u8], high_table: &[char; 128]) -> String {
    bytes
        .iter()
        .map(|&byte| match byte.checked_sub(0x80) {
            // `checked_sub` succeeds exactly for the high half of the byte range.
            Some(high) => high_table[usize::from(high)],
            None => char::from(byte),
        })
        .collect()
}
191
/// Decode Shift-JIS (codepage 932) bytes to a `String`, best-effort.
///
/// * ASCII bytes pass through unchanged.
/// * Bytes `0xA1..=0xDF` are half-width katakana and map to
///   `U+FF61..=U+FF9F`.
/// * A double-byte character (lead `0x81..=0x9F` or `0xE0..=0xFC` followed by
///   a trail `0x40..=0x7E` or `0x80..=0xFC`) becomes a single U+FFFD, because
///   decoding it properly would need the full JIS mapping table.
/// * Any other byte becomes U+FFFD on its own.
///
/// The trail byte is consumed only when it is a valid trail byte. A stray or
/// truncated lead byte therefore never swallows the ASCII character (for
/// example a line break or a quote) that follows it.
fn decode_shift_jis(bytes: &[u8]) -> String {
    const REPLACEMENT: char = '\u{FFFD}';

    let mut out = String::with_capacity(bytes.len());
    let mut input = bytes.iter().copied().peekable();

    while let Some(byte) = input.next() {
        match byte {
            0x00..=0x7F => out.push(char::from(byte)),
            // Half-width katakana block.
            0xA1..=0xDF => {
                let scalar = 0xFF61 + u32::from(byte - 0xA1);
                out.push(char::from_u32(scalar).unwrap_or(REPLACEMENT));
            }
            // Lead byte of a double-byte character.
            0x81..=0x9F | 0xE0..=0xFC => {
                if matches!(input.peek().copied(), Some(0x40..=0x7E | 0x80..=0xFC)) {
                    input.next();
                }
                out.push(REPLACEMENT);
            }
            // 0x80, 0xA0 and 0xFD..=0xFF are not valid in Shift-JIS.
            _ => out.push(REPLACEMENT),
        }
    }
    out
}
222
/// Decode EUC-KR / codepage 949 bytes to a `String`, best-effort.
///
/// ASCII bytes pass through unchanged. A double-byte character (lead
/// `0x81..=0xFE` followed by a trail `0x41..=0xFE`) becomes a single U+FFFD,
/// because no mapping table is available. Any other non-ASCII byte becomes
/// U+FFFD on its own.
///
/// The trail byte is consumed only when it lies in the trail range. A stray
/// or truncated lead byte therefore never swallows the ASCII character (for
/// example a line break or a quote) that follows it.
fn decode_euc_kr(bytes: &[u8]) -> String {
    const REPLACEMENT: char = '\u{FFFD}';

    let mut out = String::with_capacity(bytes.len());
    let mut input = bytes.iter().copied().peekable();

    while let Some(byte) = input.next() {
        if byte.is_ascii() {
            out.push(char::from(byte));
            continue;
        }
        // 0x80 and 0xFF are not lead bytes; they stand alone.
        let is_lead = matches!(byte, 0x81..=0xFE);
        if is_lead && matches!(input.peek().copied(), Some(0x41..=0xFE)) {
            input.next();
        }
        out.push(REPLACEMENT);
    }
    out
}
243
/// Decode GBK / codepage 936 bytes to a `String`, best-effort.
///
/// ASCII bytes pass through unchanged. A double-byte character (lead
/// `0x81..=0xFE` followed by a trail `0x40..=0x7E` or `0x80..=0xFE`) becomes
/// a single U+FFFD, because no mapping table is available. Any other
/// non-ASCII byte becomes U+FFFD on its own.
///
/// The trail byte is consumed only when it lies in the trail range. A stray
/// or truncated lead byte therefore never swallows the ASCII character (for
/// example a line break or a quote) that follows it.
fn decode_gbk(bytes: &[u8]) -> String {
    const REPLACEMENT: char = '\u{FFFD}';

    let mut out = String::with_capacity(bytes.len());
    let mut input = bytes.iter().copied().peekable();

    while let Some(byte) = input.next() {
        if byte.is_ascii() {
            out.push(char::from(byte));
            continue;
        }
        // 0x80 and 0xFF are not lead bytes; they stand alone.
        let is_lead = matches!(byte, 0x81..=0xFE);
        if is_lead && matches!(input.peek().copied(), Some(0x40..=0x7E | 0x80..=0xFE)) {
            input.next();
        }
        out.push(REPLACEMENT);
    }
    out
}
264
265/// Find the VBA storage prefix inside the CFB container.
266/// Returns the path prefix ending with a separator (e.g. "VBA/").
267fn find_vba_prefix(cfb: &mut cfb::CompoundFile<Cursor<&[u8]>>) -> Result<String> {
268    // Collect all entries first to avoid borrow issues.
269    let entries: Vec<String> = cfb
270        .walk()
271        .map(|e| e.path().to_string_lossy().into_owned())
272        .collect();
273
274    // Look for a "dir" stream under a VBA storage.
275    for entry_path in &entries {
276        let normalized = entry_path.replace('\\', "/");
277        if normalized.ends_with("/dir") || normalized.ends_with("/DIR") {
278            let prefix = &normalized[..normalized.len() - 3];
279            return Ok(prefix.to_string());
280        }
281    }
282
283    // Try common paths directly.
284    for prefix in ["/VBA/", "VBA/", "/"] {
285        let dir_path = format!("{prefix}dir");
286        if cfb.is_stream(&dir_path) {
287            return Ok(prefix.to_string());
288        }
289    }
290
291    Err(Error::Internal(
292        "could not find VBA dir stream in vbaProject.bin".to_string(),
293    ))
294}
295
296/// Read a stream from the CFB container as raw bytes.
297fn read_cfb_stream(cfb: &mut cfb::CompoundFile<Cursor<&[u8]>>, path: &str) -> Result<Vec<u8>> {
298    let mut stream = cfb
299        .open_stream(path)
300        .map_err(|e| Error::Internal(format!("failed to open CFB stream '{path}': {e}")))?;
301    let mut data = Vec::new();
302    stream
303        .read_to_end(&mut data)
304        .map_err(|e| Error::Internal(format!("failed to read CFB stream '{path}': {e}")))?;
305    Ok(data)
306}
307
308/// Decompress a VBA compressed stream per MS-OVBA 2.4.1.
309///
310/// The format is:
311/// - 1 byte signature (0x01)
312/// - Sequence of compressed chunks, each starting with a 2-byte header
313/// - Each chunk contains a mix of literal bytes and copy tokens
314pub fn decompress_vba_stream(data: &[u8]) -> Result<Vec<u8>> {
315    if data.is_empty() {
316        return Ok(Vec::new());
317    }
318
319    if data[0] != 0x01 {
320        return Err(Error::Internal(format!(
321            "invalid VBA compression signature: expected 0x01, got 0x{:02X}",
322            data[0]
323        )));
324    }
325
326    let mut output = Vec::with_capacity(data.len() * 2);
327    let mut pos = 1; // skip signature byte
328
329    while pos < data.len() {
330        if pos + 1 >= data.len() {
331            break;
332        }
333
334        // Read chunk header (2 bytes, little-endian)
335        let header = u16::from_le_bytes([data[pos], data[pos + 1]]);
336        pos += 2;
337
338        let chunk_size = (header & 0x0FFF) as usize + 3;
339        let is_compressed = (header & 0x8000) != 0;
340
341        let chunk_end = (pos + chunk_size - 2).min(data.len());
342
343        if !is_compressed {
344            // Uncompressed chunk: raw bytes (4096 bytes max)
345            let raw_end = chunk_end.min(pos + 4096);
346            if raw_end > data.len() {
347                break;
348            }
349            output.extend_from_slice(&data[pos..raw_end]);
350            pos = chunk_end;
351            continue;
352        }
353
354        // Compressed chunk
355        let chunk_start_output = output.len();
356        while pos < chunk_end {
357            if pos >= data.len() {
358                break;
359            }
360
361            let flag_byte = data[pos];
362            pos += 1;
363
364            for bit_index in 0..8 {
365                if pos >= chunk_end {
366                    break;
367                }
368
369                if (flag_byte >> bit_index) & 1 == 0 {
370                    // Literal byte
371                    output.push(data[pos]);
372                    pos += 1;
373                } else {
374                    // Copy token (2 bytes, little-endian)
375                    if pos + 1 >= data.len() {
376                        pos = chunk_end;
377                        break;
378                    }
379                    let token = u16::from_le_bytes([data[pos], data[pos + 1]]);
380                    pos += 2;
381
382                    // Calculate the number of bits for the length and offset
383                    let decompressed_current = output.len() - chunk_start_output;
384                    let bit_count = max_bit_count(decompressed_current);
385                    let length_mask = 0xFFFF >> bit_count;
386                    let offset_mask = !length_mask;
387
388                    let length = ((token & length_mask) + 3) as usize;
389                    let offset = (((token & offset_mask) >> (16 - bit_count)) + 1) as usize;
390
391                    if offset > output.len() {
392                        // Invalid offset, skip
393                        break;
394                    }
395
396                    let copy_start = output.len() - offset;
397                    for i in 0..length {
398                        let byte = output[copy_start + (i % offset)];
399                        output.push(byte);
400                    }
401                }
402            }
403        }
404    }
405
406    Ok(output)
407}
408
/// Width, in bits, of the length field of a copy token.
///
/// `decompressed_current` is the number of bytes of the current chunk that
/// have already been decompressed. MS-OVBA 2.4.1.3.19.1 makes the token's
/// offset field `max(ceil(log2(decompressed_current)), 4)` bits wide; this
/// function returns the complementary `16 - that`: 12 for inputs up to 16,
/// one bit less each time the input doubles, and 4 for anything above 2048.
fn max_bit_count(decompressed_current: usize) -> u16 {
    // For n >= 1, ceil(log2(n)) equals the bit length of n - 1.
    let ceil_log2 = usize::BITS - decompressed_current.saturating_sub(1).leading_zeros();
    let offset_bits = ceil_log2.clamp(4, 12);
    16 - offset_bits as u16
}
439
440/// Parse the decompressed `dir` stream to extract module entries and codepage.
441///
442/// The dir stream is a sequence of records with 2-byte IDs and 4-byte sizes.
443/// We look for MODULE_NAME, MODULE_STREAM_NAME, MODULE_OFFSET,
444/// MODULE_TYPE, and PROJECTCODEPAGE records.
445///
446/// MODULE_TYPE record 0x0021 indicates a procedural (standard) module.
447/// MODULE_TYPE record 0x0022 indicates a document/class module. When 0x0022
448/// is present, we refine the type to `Document`, `ThisWorkbook`, or `Class`
449/// based on the module name (since OOXML does not distinguish these subtypes
450/// at the record level).
451fn parse_dir_stream(data: &[u8]) -> Result<DirInfo> {
452    let mut pos = 0;
453    let mut modules = Vec::new();
454    let mut codepage: u16 = 1252; // Default to Windows-1252
455
456    // Current module being built
457    let mut current_name: Option<String> = None;
458    let mut current_stream_name: Option<String> = None;
459    let mut current_offset: u32 = 0;
460    let mut current_type = VbaModuleType::Standard;
461    let mut in_module = false;
462
463    while pos + 6 <= data.len() {
464        let record_id = u16::from_le_bytes([data[pos], data[pos + 1]]);
465        let record_size =
466            u32::from_le_bytes([data[pos + 2], data[pos + 3], data[pos + 4], data[pos + 5]])
467                as usize;
468        pos += 6;
469
470        if pos + record_size > data.len() {
471            break;
472        }
473
474        let record_data = &data[pos..pos + record_size];
475
476        match record_id {
477            // PROJECTCODEPAGE
478            0x0003 => {
479                if record_size >= 2 {
480                    codepage = u16::from_le_bytes([record_data[0], record_data[1]]);
481                }
482            }
483            // MODULENAME
484            0x0019 => {
485                if in_module {
486                    // Save previous module
487                    if let (Some(name), Some(stream)) =
488                        (current_name.take(), current_stream_name.take())
489                    {
490                        let refined_type = refine_module_type(&current_type, &name);
491                        modules.push(ModuleEntry {
492                            name,
493                            stream_name: stream,
494                            text_offset: current_offset,
495                            module_type: refined_type,
496                        });
497                    }
498                }
499                in_module = true;
500                current_name = Some(String::from_utf8_lossy(record_data).into_owned());
501                current_stream_name = None;
502                current_offset = 0;
503                current_type = VbaModuleType::Standard;
504            }
505            // MODULENAMEUNICODE
506            0x0047 => {
507                // UTF-16LE encoded name, prefer this over the ANSI name.
508                // Use only the even portion of the data; an odd trailing
509                // byte indicates a truncated record and is safely ignored.
510                if record_size >= 2 {
511                    let even_len = record_data.len() & !1;
512                    let u16_data: Vec<u16> = record_data[..even_len]
513                        .chunks_exact(2)
514                        .map(|c| u16::from_le_bytes([c[0], c[1]]))
515                        .collect();
516                    let name = String::from_utf16_lossy(&u16_data);
517                    // Remove trailing null if present
518                    let name = name.trim_end_matches('\0').to_string();
519                    if !name.is_empty() {
520                        current_name = Some(name);
521                    }
522                }
523            }
524            // MODULESTREAMNAME
525            0x001A => {
526                current_stream_name = Some(String::from_utf8_lossy(record_data).into_owned());
527                // The MODULENAMEUNICODE record for stream name follows with id 0x0032
528                // We handle it inline: skip the unicode record
529                if pos + record_size + 6 <= data.len() {
530                    let next_id =
531                        u16::from_le_bytes([data[pos + record_size], data[pos + record_size + 1]]);
532                    if next_id == 0x0032 {
533                        let next_size = u32::from_le_bytes([
534                            data[pos + record_size + 2],
535                            data[pos + record_size + 3],
536                            data[pos + record_size + 4],
537                            data[pos + record_size + 5],
538                        ]) as usize;
539                        // Skip the unicode stream name record
540                        pos += record_size + 6 + next_size;
541                        continue;
542                    }
543                }
544            }
545            // MODULEOFFSET
546            0x0031 => {
547                if record_size >= 4 {
548                    current_offset = u32::from_le_bytes([
549                        record_data[0],
550                        record_data[1],
551                        record_data[2],
552                        record_data[3],
553                    ]);
554                }
555            }
556            // MODULETYPE procedural (0x0021)
557            0x0021 => {
558                current_type = VbaModuleType::Standard;
559            }
560            // MODULETYPE document/class (0x0022)
561            0x0022 => {
562                // The dir stream only distinguishes procedural (0x0021) from
563                // non-procedural (0x0022). We refine 0x0022 into Document,
564                // ThisWorkbook, or Class based on the module name when the
565                // module is finalized.
566                current_type = VbaModuleType::Class;
567            }
568            // TERMINATOR for modules section (0x002B)
569            0x002B => {
570                // End of module list
571            }
572            _ => {}
573        }
574
575        pos += record_size;
576    }
577
578    // Save the last module if present
579    if in_module {
580        if let (Some(name), Some(stream)) = (current_name, current_stream_name) {
581            let refined_type = refine_module_type(&current_type, &name);
582            modules.push(ModuleEntry {
583                name,
584                stream_name: stream,
585                text_offset: current_offset,
586                module_type: refined_type,
587            });
588        }
589    }
590
591    Ok(DirInfo {
592        entries: modules,
593        codepage,
594    })
595}
596
597/// Refine the module type for non-procedural modules (0x0022) based on
598/// the module name. Procedural modules (0x0021) are always `Standard`.
599fn refine_module_type(base_type: &VbaModuleType, name: &str) -> VbaModuleType {
600    if *base_type == VbaModuleType::Standard {
601        return VbaModuleType::Standard;
602    }
603    let name_lower = name.to_lowercase();
604    if name_lower == "thisworkbook" {
605        VbaModuleType::ThisWorkbook
606    } else if name_lower.starts_with("sheet") {
607        VbaModuleType::Document
608    } else {
609        // Remains as Class (could be a class module or UserForm).
610        VbaModuleType::Class
611    }
612}
613
614#[cfg(test)]
615#[allow(clippy::same_item_push)]
616mod tests {
617    use super::*;
618
619    #[test]
620    fn test_decompress_empty_input() {
621        let result = decompress_vba_stream(&[]);
622        assert!(result.is_ok());
623        assert!(result.unwrap().is_empty());
624    }
625
626    #[test]
627    fn test_decompress_invalid_signature() {
628        let result = decompress_vba_stream(&[0x00, 0x01, 0x02]);
629        assert!(result.is_err());
630        let err_msg = result.unwrap_err().to_string();
631        assert!(err_msg.contains("invalid VBA compression signature"));
632    }
633
634    #[test]
635    fn test_decompress_uncompressed_chunk() {
636        // Signature byte + uncompressed chunk header (size=3 -> 3+3-2=4 bytes, bit 15 clear)
637        // Header: chunk_size = 3 bytes (field = 3-3 = 0), not compressed (bit 15 = 0)
638        // So header = 0x0000 means size=3, uncompressed
639        let mut data = vec![0x01]; // signature
640                                   // Uncompressed chunk: header with bit 15 clear, size field = N-3
641                                   // For 4 bytes of data: chunk_size = 4, field = 4-3 = 1
642        let header: u16 = 0x0001; // bit 15 = 0 (uncompressed), size = 1+3-2 = 2 (actual chunk payload = 2)
643                                  // Wait, let me recalculate.
644                                  // chunk_size = (header & 0x0FFF) + 3 = field + 3
645                                  // The chunk payload is chunk_size - 2 = field + 1 bytes
646                                  // For 3 bytes of payload: field = 2, header = 0x0002
647        data.extend_from_slice(&header.to_le_bytes());
648        data.extend_from_slice(b"AB");
649        // This should produce "AB" but limited to min(chunk_end, pos+4096)
650        let result = decompress_vba_stream(&data).unwrap();
651        assert_eq!(&result, b"AB");
652    }
653
654    #[test]
655    fn test_decompress_real_compressed_data() {
656        // Test with a known compressed sequence from the MS-OVBA spec example.
657        // Compressed representation of "aaaaaaaaaaaaaaa" (15 'a's)
658        // Signature: 0x01
659        // Chunk header: compressed, size field
660        // Flag byte: 0b00000011 = 0x03 (bit 0: literal, bit 1: copy token)
661        // Actually building a minimal valid compressed stream:
662        // Signature: 0x01
663        // Chunk header: size = N-3, compressed bit set
664        // Then flag + data
665
666        // A simpler approach: verify that decompression of a manually built stream works.
667        let mut compressed = vec![0x01u8];
668        // Build chunk: 1 literal 'a', then copy token referencing offset=1, length=3
669        // Flag byte: 0b00000010 = bit 0 literal, bit 1 copy
670        // Literal: b'a'
671        // Copy token with decompressed_current=1 -> bit_count=12, length_mask=0x000F
672        // offset=1, length=3 -> offset_field=(1-1)<<4=0, length_field=3-3=0
673        // token = 0x0000
674        let flag = 0x02u8; // bits: 0=literal, 1=copy, rest=0
675        let literal = b'a';
676        let copy_token: u16 = 0x0000; // offset=1, length=3
677
678        let mut chunk_payload = Vec::new();
679        chunk_payload.push(flag);
680        chunk_payload.push(literal);
681        chunk_payload.extend_from_slice(&copy_token.to_le_bytes());
682
683        let chunk_size = chunk_payload.len() + 2; // +2 for header
684        let header: u16 = 0x8000 | ((chunk_size as u16 - 3) & 0x0FFF);
685        compressed.extend_from_slice(&header.to_le_bytes());
686        compressed.extend_from_slice(&chunk_payload);
687
688        let result = decompress_vba_stream(&compressed).unwrap();
689        assert_eq!(&result, b"aaaa"); // 1 literal + 3 from copy
690    }
691
692    #[test]
693    fn test_max_bit_count() {
694        assert_eq!(max_bit_count(0), 12);
695        assert_eq!(max_bit_count(1), 12);
696        assert_eq!(max_bit_count(16), 12);
697        assert_eq!(max_bit_count(17), 11);
698        assert_eq!(max_bit_count(32), 11);
699        assert_eq!(max_bit_count(33), 10);
700        assert_eq!(max_bit_count(64), 10);
701        assert_eq!(max_bit_count(65), 9);
702        assert_eq!(max_bit_count(128), 9);
703        assert_eq!(max_bit_count(129), 8);
704        assert_eq!(max_bit_count(256), 8);
705        assert_eq!(max_bit_count(257), 7);
706        assert_eq!(max_bit_count(512), 7);
707        assert_eq!(max_bit_count(513), 6);
708        assert_eq!(max_bit_count(1024), 6);
709        assert_eq!(max_bit_count(1025), 5);
710        assert_eq!(max_bit_count(2048), 5);
711        assert_eq!(max_bit_count(2049), 4);
712        assert_eq!(max_bit_count(4096), 4);
713    }
714
715    #[test]
716    fn test_parse_dir_stream_empty() {
717        let result = parse_dir_stream(&[]);
718        assert!(result.is_ok());
719        let info = result.unwrap();
720        assert!(info.entries.is_empty());
721        assert_eq!(info.codepage, 1252);
722    }
723
724    #[test]
725    fn test_extract_vba_modules_invalid_cfb() {
726        let result = extract_vba_modules(b"not a CFB file");
727        assert!(result.is_err());
728        let err_msg = result.unwrap_err().to_string();
729        assert!(err_msg.contains("failed to open VBA project as CFB"));
730    }
731
732    #[test]
733    fn test_vba_module_type_clone() {
734        let t = VbaModuleType::Standard;
735        let t2 = t.clone();
736        assert_eq!(t, t2);
737    }
738
739    #[test]
740    fn test_vba_module_debug() {
741        let m = VbaModule {
742            name: "Module1".to_string(),
743            source_code: "Sub Test()\nEnd Sub".to_string(),
744            module_type: VbaModuleType::Standard,
745        };
746        let debug = format!("{:?}", m);
747        assert!(debug.contains("Module1"));
748    }
749
750    #[test]
751    fn test_vba_roundtrip_with_xlsm() {
752        use std::io::{Read as _, Write as _};
753
754        // Build a minimal CFB container with a VBA dir stream and a module
755        let vba_bin = build_test_vba_project();
756
757        // Create a valid xlsx using the Workbook API, then inject vbaProject.bin
758        let base_wb = crate::workbook::Workbook::new();
759        let base_buf = base_wb.save_to_buffer().unwrap();
760
761        // Rewrite the ZIP, adding the vbaProject.bin entry
762        let mut buf = Vec::new();
763        {
764            let base_cursor = std::io::Cursor::new(&base_buf);
765            let mut base_archive = zip::ZipArchive::new(base_cursor).unwrap();
766
767            let out_cursor = std::io::Cursor::new(&mut buf);
768            let mut zip = zip::ZipWriter::new(out_cursor);
769            let options = zip::write::SimpleFileOptions::default()
770                .compression_method(zip::CompressionMethod::Deflated);
771
772            for i in 0..base_archive.len() {
773                let mut entry = base_archive.by_index(i).unwrap();
774                let name = entry.name().to_string();
775                zip.start_file(&name, options).unwrap();
776                let mut data = Vec::new();
777                entry.read_to_end(&mut data).unwrap();
778                zip.write_all(&data).unwrap();
779            }
780
781            zip.start_file("xl/vbaProject.bin", options).unwrap();
782            zip.write_all(&vba_bin).unwrap();
783            zip.finish().unwrap();
784        }
785
786        // Open and extract
787        let opts = crate::workbook::OpenOptions::new()
788            .read_mode(crate::workbook::ReadMode::Eager)
789            .aux_parts(crate::workbook::AuxParts::EagerLoad);
790        let wb = crate::workbook::Workbook::open_from_buffer_with_options(&buf, &opts).unwrap();
791
792        // Raw VBA project should be available
793        let raw = wb.get_vba_project();
794        assert!(raw.is_some(), "VBA project binary should be present");
795        assert_eq!(raw.unwrap(), vba_bin);
796    }
797
798    #[test]
799    fn test_xlsx_without_vba_returns_none() {
800        let wb = crate::workbook::Workbook::new();
801        assert!(wb.get_vba_project().is_none());
802        assert!(wb.get_vba_modules().unwrap().is_none());
803    }
804
805    #[test]
806    fn test_xlsx_roundtrip_no_vba() {
807        let wb = crate::workbook::Workbook::new();
808        let buf = wb.save_to_buffer().unwrap();
809        let wb2 = crate::workbook::Workbook::open_from_buffer(&buf).unwrap();
810        assert!(wb2.get_vba_project().is_none());
811    }
812
813    #[test]
814    fn test_get_vba_modules_from_test_project() {
815        use std::io::{Read as _, Write as _};
816
817        let vba_bin = build_test_vba_project();
818
819        // Create a valid xlsx, then inject vbaProject.bin
820        let base_wb = crate::workbook::Workbook::new();
821        let base_buf = base_wb.save_to_buffer().unwrap();
822
823        let mut buf = Vec::new();
824        {
825            let base_cursor = std::io::Cursor::new(&base_buf);
826            let mut base_archive = zip::ZipArchive::new(base_cursor).unwrap();
827
828            let out_cursor = std::io::Cursor::new(&mut buf);
829            let mut zip = zip::ZipWriter::new(out_cursor);
830            let options = zip::write::SimpleFileOptions::default()
831                .compression_method(zip::CompressionMethod::Deflated);
832
833            for i in 0..base_archive.len() {
834                let mut entry = base_archive.by_index(i).unwrap();
835                let name = entry.name().to_string();
836                zip.start_file(&name, options).unwrap();
837                let mut data = Vec::new();
838                entry.read_to_end(&mut data).unwrap();
839                zip.write_all(&data).unwrap();
840            }
841
842            zip.start_file("xl/vbaProject.bin", options).unwrap();
843            zip.write_all(&vba_bin).unwrap();
844            zip.finish().unwrap();
845        }
846
847        let opts = crate::workbook::OpenOptions::new()
848            .read_mode(crate::workbook::ReadMode::Eager)
849            .aux_parts(crate::workbook::AuxParts::EagerLoad);
850        let wb = crate::workbook::Workbook::open_from_buffer_with_options(&buf, &opts).unwrap();
851        let project = wb.get_vba_modules().unwrap();
852        assert!(project.is_some(), "should have VBA modules");
853        let project = project.unwrap();
854        assert_eq!(project.modules.len(), 1);
855        assert_eq!(project.modules[0].name, "Module1");
856        assert_eq!(project.modules[0].module_type, VbaModuleType::Standard);
857        assert!(
858            project.modules[0].source_code.contains("Sub Hello()"),
859            "source should contain Sub Hello(), got: {}",
860            project.modules[0].source_code
861        );
862    }
863
864    #[test]
865    fn test_vba_project_preserved_in_save_roundtrip() {
866        use std::io::{Read as _, Write as _};
867
868        let vba_bin = build_test_vba_project();
869
870        let base_wb = crate::workbook::Workbook::new();
871        let base_buf = base_wb.save_to_buffer().unwrap();
872
873        let mut buf = Vec::new();
874        {
875            let base_cursor = std::io::Cursor::new(&base_buf);
876            let mut base_archive = zip::ZipArchive::new(base_cursor).unwrap();
877
878            let out_cursor = std::io::Cursor::new(&mut buf);
879            let mut zip = zip::ZipWriter::new(out_cursor);
880            let options = zip::write::SimpleFileOptions::default()
881                .compression_method(zip::CompressionMethod::Deflated);
882
883            for i in 0..base_archive.len() {
884                let mut entry = base_archive.by_index(i).unwrap();
885                let name = entry.name().to_string();
886                zip.start_file(&name, options).unwrap();
887                let mut data = Vec::new();
888                entry.read_to_end(&mut data).unwrap();
889                zip.write_all(&data).unwrap();
890            }
891
892            zip.start_file("xl/vbaProject.bin", options).unwrap();
893            zip.write_all(&vba_bin).unwrap();
894            zip.finish().unwrap();
895        }
896
897        // Open, then save again
898        let opts = crate::workbook::OpenOptions::new()
899            .read_mode(crate::workbook::ReadMode::Eager)
900            .aux_parts(crate::workbook::AuxParts::EagerLoad);
901        let wb = crate::workbook::Workbook::open_from_buffer_with_options(&buf, &opts).unwrap();
902        let saved_buf = wb.save_to_buffer().unwrap();
903
904        // Re-open and verify VBA is preserved
905        let wb2 =
906            crate::workbook::Workbook::open_from_buffer_with_options(&saved_buf, &opts).unwrap();
907        let raw = wb2.get_vba_project();
908        assert!(raw.is_some(), "VBA project should survive save roundtrip");
909        assert_eq!(raw.unwrap(), vba_bin);
910
911        // Modules should still be extractable
912        let project = wb2.get_vba_modules().unwrap().unwrap();
913        assert_eq!(project.modules.len(), 1);
914        assert_eq!(project.modules[0].name, "Module1");
915    }
916
917    /// Build a minimal CFB container that looks like a VBA project.
918    fn build_test_vba_project() -> Vec<u8> {
919        let mut buf = Vec::new();
920        let cursor = std::io::Cursor::new(&mut buf);
921        let mut cfb = cfb::CompoundFile::create(cursor).unwrap();
922
923        // Create VBA storage
924        cfb.create_storage("/VBA").unwrap();
925
926        // Build a minimal dir stream
927        let dir_data = build_minimal_dir_stream("Module1");
928
929        // Compress the dir stream
930        let compressed_dir = compress_for_test(&dir_data);
931
932        // Write dir stream
933        {
934            let mut stream = cfb.create_stream("/VBA/dir").unwrap();
935            std::io::Write::write_all(&mut stream, &compressed_dir).unwrap();
936        }
937
938        // Build module source: "Sub Hello()\nEnd Sub\n"
939        let source = b"Sub Hello()\r\nEnd Sub\r\n";
940        let compressed_source = compress_for_test(source);
941
942        // The module stream has 0 bytes of performance cache + compressed source.
943        // (text_offset = 0 in the dir stream)
944        {
945            let mut stream = cfb.create_stream("/VBA/Module1").unwrap();
946            std::io::Write::write_all(&mut stream, &compressed_source).unwrap();
947        }
948
949        // Create _VBA_PROJECT stream (required for validity, can be minimal)
950        {
951            let mut stream = cfb.create_stream("/VBA/_VBA_PROJECT").unwrap();
952            // Minimal header: version bytes
953            let header = [0xCC, 0x61, 0x00, 0x00, 0x00, 0x00, 0x00];
954            std::io::Write::write_all(&mut stream, &header).unwrap();
955        }
956
957        cfb.flush().unwrap();
958        buf
959    }
960
961    /// Build a minimal dir stream binary for one standard module.
962    fn build_minimal_dir_stream(module_name: &str) -> Vec<u8> {
963        let mut data = Vec::new();
964        let name_bytes = module_name.as_bytes();
965
966        // PROJECTSYSKIND record (0x0001): 4 bytes, value = 1 (Win32)
967        write_dir_record(&mut data, 0x0001, &1u32.to_le_bytes());
968
969        // PROJECTLCID record (0x0002): 4 bytes
970        write_dir_record(&mut data, 0x0002, &0x0409u32.to_le_bytes());
971
972        // PROJECTLCIDINVOKE record (0x0014): 4 bytes
973        write_dir_record(&mut data, 0x0014, &0x0409u32.to_le_bytes());
974
975        // PROJECTCODEPAGE record (0x0003): 2 bytes (1252 = Windows-1252)
976        write_dir_record(&mut data, 0x0003, &1252u16.to_le_bytes());
977
978        // PROJECTNAME record (0x0004)
979        write_dir_record(&mut data, 0x0004, b"VBAProject");
980
981        // PROJECTDOCSTRING record (0x0005): empty
982        write_dir_record(&mut data, 0x0005, &[]);
983        // Unicode variant (0x0040): empty
984        write_dir_record(&mut data, 0x0040, &[]);
985
986        // PROJECTHELPFILEPATH record (0x0006): empty
987        write_dir_record(&mut data, 0x0006, &[]);
988        // Unicode variant (0x003D): empty
989        write_dir_record(&mut data, 0x003D, &[]);
990
991        // PROJECTHELPCONTEXT (0x0007): 4 bytes
992        write_dir_record(&mut data, 0x0007, &0u32.to_le_bytes());
993
994        // PROJECTLIBFLAGS (0x0008): 4 bytes
995        write_dir_record(&mut data, 0x0008, &0u32.to_le_bytes());
996
997        // PROJECTVERSION (0x0009): 4 + 2 bytes (major + minor)
998        let mut version = Vec::new();
999        version.extend_from_slice(&1u32.to_le_bytes());
1000        version.extend_from_slice(&0u16.to_le_bytes());
1001        // Version record is special: id=0x0009, size=4 for major, then 2 bytes minor appended
1002        write_dir_record(&mut data, 0x0009, &version);
1003
1004        // PROJECTCONSTANTS (0x000C): empty
1005        write_dir_record(&mut data, 0x000C, &[]);
1006        // Unicode variant (0x003C): empty
1007        write_dir_record(&mut data, 0x003C, &[]);
1008
1009        // MODULES count record: id=0x000F, size=2
1010        let module_count: u16 = 1;
1011        write_dir_record(&mut data, 0x000F, &module_count.to_le_bytes());
1012
1013        // PROJECTCOOKIE record (0x0013): 2 bytes
1014        write_dir_record(&mut data, 0x0013, &0u16.to_le_bytes());
1015
1016        // MODULE_NAME record (0x0019)
1017        write_dir_record(&mut data, 0x0019, name_bytes);
1018
1019        // MODULE_STREAM_NAME record (0x001A)
1020        write_dir_record(&mut data, 0x001A, name_bytes);
1021        // Unicode variant (0x0032)
1022        let name_utf16: Vec<u8> = module_name
1023            .encode_utf16()
1024            .flat_map(|c| c.to_le_bytes())
1025            .collect();
1026        write_dir_record(&mut data, 0x0032, &name_utf16);
1027
1028        // MODULE_OFFSET record (0x0031): 4 bytes (offset = 0)
1029        write_dir_record(&mut data, 0x0031, &0u32.to_le_bytes());
1030
1031        // MODULE_TYPE procedural (0x0021): 0 bytes
1032        write_dir_record(&mut data, 0x0021, &[]);
1033
1034        // MODULE_TERMINATOR (0x002B): 0 bytes
1035        write_dir_record(&mut data, 0x002B, &[]);
1036
1037        // End of modules
1038        // Global TERMINATOR (0x0010): 0 bytes
1039        write_dir_record(&mut data, 0x0010, &[]);
1040
1041        data
1042    }
1043
1044    fn write_dir_record(buf: &mut Vec<u8>, id: u16, data: &[u8]) {
1045        buf.extend_from_slice(&id.to_le_bytes());
1046        buf.extend_from_slice(&(data.len() as u32).to_le_bytes());
1047        buf.extend_from_slice(data);
1048    }
1049
1050    /// Minimal MS-OVBA "compression" that produces an uncompressed container.
1051    /// Signature 0x01 + one uncompressed chunk per 4096 bytes.
1052    fn compress_for_test(data: &[u8]) -> Vec<u8> {
1053        let mut result = vec![0x01u8]; // signature
1054        let mut pos = 0;
1055        while pos < data.len() {
1056            let chunk_len = (data.len() - pos).min(4096);
1057            let chunk_data = &data[pos..pos + chunk_len];
1058            // Chunk header: bit 15 = 0 (uncompressed), bits 0-11 = chunk_len + 2 - 3
1059            let header: u16 = (chunk_len as u16 + 2).wrapping_sub(3) & 0x0FFF;
1060            result.extend_from_slice(&header.to_le_bytes());
1061            result.extend_from_slice(chunk_data);
1062            // Pad to 4096 if needed
1063            for _ in chunk_len..4096 {
1064                result.push(0x00);
1065            }
1066            pos += chunk_len;
1067        }
1068        result
1069    }
1070}