package wal import ( "bufio" "encoding/binary" "fmt" "hash/crc32" "io" "os" "path/filepath" "sort" ) // Reader reads entries from WAL files type Reader struct { file *os.File reader *bufio.Reader buffer []byte fragments [][]byte currType uint8 } // OpenReader creates a new Reader for the given WAL file func OpenReader(path string) (*Reader, error) { file, err := os.Open(path) if err != nil { return nil, fmt.Errorf("failed to open WAL file: %w", err) } return &Reader{ file: file, reader: bufio.NewReaderSize(file, 64*1024), // 64KB buffer buffer: make([]byte, MaxRecordSize), fragments: make([][]byte, 0), }, nil } // ReadEntry reads the next entry from the WAL func (r *Reader) ReadEntry() (*Entry, error) { // Loop until we have a complete entry for { // Read a record record, err := r.readRecord() if err != nil { if err == io.EOF { // If we have fragments, this is unexpected EOF if len(r.fragments) > 0 { return nil, fmt.Errorf("unexpected EOF with %d fragments", len(r.fragments)) } return nil, io.EOF } return nil, err } // Process based on record type switch record.recordType { case RecordTypeFull: // Single record, parse directly return r.parseEntryData(record.data) case RecordTypeFirst: // Start of a fragmented entry r.fragments = append(r.fragments, record.data) r.currType = record.data[0] // Save the operation type case RecordTypeMiddle: // Middle fragment if len(r.fragments) == 0 { return nil, fmt.Errorf("%w: middle fragment without first fragment", ErrCorruptRecord) } r.fragments = append(r.fragments, record.data) case RecordTypeLast: // Last fragment if len(r.fragments) == 0 { return nil, fmt.Errorf("%w: last fragment without previous fragments", ErrCorruptRecord) } r.fragments = append(r.fragments, record.data) // Combine fragments into a single entry entry, err := r.processFragments() if err != nil { return nil, err } return entry, nil default: return nil, fmt.Errorf("%w: %d", ErrInvalidRecordType, record.recordType) } } } // Record represents a physical record in the WAL type record struct { recordType uint8 data []byte } // readRecord reads a single physical record from the WAL func (r *Reader) readRecord() (*record, error) { // Read header header := make([]byte, HeaderSize) if _, err := io.ReadFull(r.reader, header); err != nil { return nil, err } // Parse header crc := binary.LittleEndian.Uint32(header[0:4]) length := binary.LittleEndian.Uint16(header[4:6]) recordType := header[6] // Validate record type if recordType < RecordTypeFull || recordType > RecordTypeLast { return nil, fmt.Errorf("%w: %d", ErrInvalidRecordType, recordType) } // Read payload data := make([]byte, length) if _, err := io.ReadFull(r.reader, data); err != nil { return nil, err } // Verify CRC computedCRC := crc32.ChecksumIEEE(data) if computedCRC != crc { return nil, fmt.Errorf("%w: expected CRC %d, got %d", ErrCorruptRecord, crc, computedCRC) } return &record{ recordType: recordType, data: data, }, nil } // processFragments combines fragments into a single entry func (r *Reader) processFragments() (*Entry, error) { // Determine total size totalSize := 0 for _, frag := range r.fragments { totalSize += len(frag) } // Combine fragments combined := make([]byte, totalSize) offset := 0 for _, frag := range r.fragments { copy(combined[offset:], frag) offset += len(frag) } // Reset fragments r.fragments = r.fragments[:0] // Parse the combined data into an entry return r.parseEntryData(combined) } // parseEntryData parses the binary data into an Entry structure func (r *Reader) parseEntryData(data []byte) (*Entry, error) { if len(data) < 13 { // Minimum size: type(1) + seq(8) + keylen(4) return nil, fmt.Errorf("%w: entry too small, %d bytes", ErrCorruptRecord, len(data)) } offset := 0 // Read entry type entryType := data[offset] offset++ // Validate entry type if entryType != OpTypePut && entryType != OpTypeDelete && entryType != OpTypeMerge && entryType != OpTypeBatch { return nil, fmt.Errorf("%w: %d", ErrInvalidOpType, entryType) } // Read sequence number seqNum := binary.LittleEndian.Uint64(data[offset : offset+8]) offset += 8 // Read key length keyLen := binary.LittleEndian.Uint32(data[offset : offset+4]) offset += 4 // Validate key length if offset+int(keyLen) > len(data) { return nil, fmt.Errorf("%w: invalid key length %d", ErrCorruptRecord, keyLen) } // Read key key := make([]byte, keyLen) copy(key, data[offset:offset+int(keyLen)]) offset += int(keyLen) // Read value if applicable var value []byte if entryType != OpTypeDelete { // Check if there's enough data for value length if offset+4 > len(data) { return nil, fmt.Errorf("%w: missing value length", ErrCorruptRecord) } // Read value length valueLen := binary.LittleEndian.Uint32(data[offset : offset+4]) offset += 4 // Validate value length if offset+int(valueLen) > len(data) { return nil, fmt.Errorf("%w: invalid value length %d", ErrCorruptRecord, valueLen) } // Read value value = make([]byte, valueLen) copy(value, data[offset:offset+int(valueLen)]) } return &Entry{ SequenceNumber: seqNum, Type: entryType, Key: key, Value: value, }, nil } // Close closes the reader func (r *Reader) Close() error { return r.file.Close() } // EntryHandler is a function that processes WAL entries during replay type EntryHandler func(*Entry) error // FindWALFiles returns a list of WAL files in the given directory func FindWALFiles(dir string) ([]string, error) { pattern := filepath.Join(dir, "*.wal") matches, err := filepath.Glob(pattern) if err != nil { return nil, fmt.Errorf("failed to glob WAL files: %w", err) } // Sort by filename (which should be timestamp-based) sort.Strings(matches) return matches, nil } // ReplayWALFile replays a single WAL file and calls the handler for each entry func ReplayWALFile(path string, handler EntryHandler) error { reader, err := OpenReader(path) if err != nil { return err } defer reader.Close() for { entry, err := reader.ReadEntry() if err != nil { if err == io.EOF { break } return fmt.Errorf("error reading entry from %s: %w", path, err) } if err := handler(entry); err != nil { return fmt.Errorf("error handling entry: %w", err) } } return nil } // ReplayWALDir replays all WAL files in the given directory in order func ReplayWALDir(dir string, handler EntryHandler) error { files, err := FindWALFiles(dir) if err != nil { return err } for _, file := range files { if err := ReplayWALFile(file, handler); err != nil { return err } } return nil }