// Package pagealloc is the Go reference implementation for db-01-storage-primitives.
//
// Page layout (16-byte header + payload, zero-padded to PAGE_SIZE):
//
//	offset  size  field
//	------  ----  -----
//	   0     8    magic = MAGIC (little-endian uint64)
//	   8     2    version = 1   (little-endian uint16)
//	  10     2    flags = 0     (little-endian uint16)
//	  12     4    payload_len   (little-endian uint32)
//	  16     n    payload bytes
//	n+16     -    zero-pad to PAGE_SIZE
//
// Page I/O (ReadPage/WritePage) uses positional ReadAt/WriteAt (pread/pwrite), never Seek followed by Read/Write.
package pagealloc

import (
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"os"
)

const (
	// PageSize is the fixed size of every on-disk page, in bytes. 4 KiB is
	// the usual base page size on x86-64 and a common NVMe logical block
	// size, but it is a format constant, not a property probed from the
	// host: some ARM64 systems use 16 KiB or 64 KiB pages, and many drives
	// are formatted with 512-byte LBAs.
	PageSize = 4096

	// PageMagic identifies a page written by this package. Stored
	// little-endian, its on-disk bytes are 45 47 41 50 31 45 53 44,
	// i.e. the ASCII string "EGAP1ESD".
	PageMagic uint64 = 0x44534531_50414745

	// PageVersion is the header format version stored at offset 8.
	PageVersion uint16 = 1

	// HeaderLen is the size in bytes of the fixed page header
	// (magic, version, flags, payload_len).
	HeaderLen = 16

	// MaxPayload is the largest payload, in bytes, that fits in one page
	// after the header.
	MaxPayload = PageSize - HeaderLen
)

// Sentinel errors for this package. The functions here return them wrapped
// with page or size context (via %w), so test for them with errors.Is
// rather than ==.
var (
	// ErrPayloadTooLarge reports a payload longer than MaxPayload bytes.
	ErrPayloadTooLarge = errors.New("payload too large")

	// ErrPageMissing reports a page that lies at or past end of file, or is
	// only partially present there.
	ErrPageMissing = errors.New("page missing")

	// ErrBadMagic reports a page whose magic field is not PageMagic.
	ErrBadMagic = errors.New("bad page magic")

	// ErrBadPayloadLen reports a header payload_len larger than MaxPayload.
	ErrBadPayloadLen = errors.New("bad payload_len")
)

// EncodePage serializes payload into one PageSize-byte page image: the
// 16-byte header (magic, version, flags, payload length — all little-endian)
// followed by the payload bytes and zero padding.
//
// If payload is longer than MaxPayload it returns an all-zero array and an
// error wrapping ErrPayloadTooLarge.
func EncodePage(payload []byte) ([PageSize]byte, error) {
	var page [PageSize]byte
	n := len(payload)
	if n > MaxPayload {
		return page, fmt.Errorf("%w: %d bytes > %d", ErrPayloadTooLarge, n, MaxPayload)
	}

	le := binary.LittleEndian
	hdr := page[:HeaderLen]
	le.PutUint64(hdr[0:8], PageMagic)
	le.PutUint16(hdr[8:10], PageVersion)
	le.PutUint16(hdr[10:12], 0) // flags are always 0 in this layout
	le.PutUint32(hdr[12:16], uint32(n))

	// The array starts zeroed, so everything after the payload is padding.
	copy(page[HeaderLen:], payload)
	return page, nil
}

// DecodePage validates the header of the page image in buf and returns its
// payload.
//
// pageNo is used only to give error messages context. The returned slice
// aliases buf (no copy); its capacity is clipped to its length so that an
// append by the caller reallocates instead of overwriting the page padding.
//
// Errors (wrapped; match with errors.Is):
//   - ErrBadMagic if the magic field is not PageMagic;
//   - ErrBadPayloadLen if payload_len exceeds MaxPayload.
//
// The version and flags fields are not checked.
func DecodePage(pageNo uint64, buf *[PageSize]byte) ([]byte, error) {
	magic := binary.LittleEndian.Uint64(buf[0:8])
	if magic != PageMagic {
		return nil, fmt.Errorf("%w: page %d: found 0x%016x expected 0x%016x", ErrBadMagic, pageNo, magic, PageMagic)
	}
	// Bound-check in uint32 before converting: on 32-bit platforms int() of a
	// corrupt length >= 2^31 goes negative, would pass a signed "> MaxPayload"
	// check, and would then panic in the slice expression below.
	plen := binary.LittleEndian.Uint32(buf[12:16])
	if plen > MaxPayload {
		return nil, fmt.Errorf("%w: page %d: len %d > %d", ErrBadPayloadLen, pageNo, plen, MaxPayload)
	}
	end := HeaderLen + int(plen)
	return buf[HeaderLen:end:end], nil
}

// WritePage encodes payload as a page image and writes it at page slot
// pageNo (byte offset pageNo*PageSize) of the file at path. The file is
// created with mode 0644 (before umask) if it does not exist. SyncData is
// called on the file (intended as fdatasync, or F_FULLFSYNC on macOS)
// before WritePage returns.
//
// Errors:
//   - a wrapped ErrPayloadTooLarge from EncodePage (checked first);
//   - an error if pageNo is so large that its byte offset overflows int64;
//   - otherwise whatever opening, writing, syncing or closing the file reported.
//
// NOTE(review): syncing the file alone does not make a newly created file's
// directory entry durable; a crash-safe create would also fsync the parent
// directory.
func WritePage(path string, pageNo uint64, payload []byte) (err error) {
	buf, err := EncodePage(payload)
	if err != nil {
		return err
	}
	// int64(pageNo)*PageSize silently wraps for larger page numbers and can
	// produce a valid positive offset, i.e. overwrite an unrelated page.
	const maxPageNo = (1<<63 - 1) / PageSize
	if pageNo > maxPageNo {
		return fmt.Errorf("page %d: byte offset overflows int64", pageNo)
	}

	f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0644)
	if err != nil {
		return err
	}
	// Report a Close failure unless an earlier error is already being returned.
	defer func() {
		if cerr := f.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	offset := int64(pageNo) * PageSize
	if _, err := f.WriteAt(buf[:], offset); err != nil {
		return err
	}
	return SyncData(f)
}

// ReadPage reads page pageNo from path and returns the payload bytes.
func ReadPage(path string, pageNo uint64) ([]byte, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer f.Close()
	var buf [PageSize]byte
	offset := int64(pageNo) * int64(PageSize)
	n, err := f.ReadAt(buf[:], offset)
	if err == io.EOF && n == 0 {
		return nil, fmt.Errorf("%w: page %d", ErrPageMissing, pageNo)
	}
	if err != nil && err != io.EOF {
		return nil, err
	}
	if n < PageSize {
		// Short read at EOF — likely a partial last page; treat as missing for cleanliness.
		return nil, fmt.Errorf("%w: page %d (short read %d)", ErrPageMissing, pageNo, n)
	}
	payload, err := DecodePage(pageNo, &buf)
	if err != nil {
		return nil, err
	}
	out := make([]byte, len(payload))
	copy(out, payload)
	return out, nil
}

// Hexdump writes a dump of the file at path to w in xxd's default layout.
// Each line has:
//   - a lowercase hex offset, at least 8 digits, followed by ": ";
//   - up to 16 bytes as eight 4-digit hex groups, each followed by a space
//     (missing bytes on a short final line are padded with spaces);
//   - one extra space;
//   - the bytes as ASCII, with anything outside 0x20..0x7e shown as '.'.
//
// The file is read with positional ReadAt, like the rest of the package.
// ReadAt returns fewer bytes than requested only at EOF or on error, so every
// chunk before the last is a whole multiple of 16 bytes and no line is split
// mid-file. A plain Read may legally return a short count, which would break
// that alignment. Each line is written to w with a single Write call.
//
// It returns the first error from opening or reading the file or from writing
// to w. Bytes read before a read error are still dumped.
func Hexdump(path string, w io.Writer) error {
	f, err := os.Open(path)
	if err != nil {
		return err
	}
	defer f.Close()

	buf := make([]byte, 4096) // a multiple of 16, so chunks stay line-aligned
	line := make([]byte, 0, 80)
	var offset int64
	for {
		n, err := f.ReadAt(buf, offset)
		for start := 0; start < n; start += 16 {
			end := start + 16
			if end > n {
				end = n
			}
			line = appendHexdumpLine(line[:0], uint64(offset)+uint64(start), buf[start:end])
			if _, werr := w.Write(line); werr != nil {
				return werr
			}
		}
		offset += int64(n)
		if errors.Is(err, io.EOF) {
			return nil
		}
		if err != nil {
			return err
		}
	}
}

// appendHexdumpLine appends to dst one xxd-style line, including the trailing
// newline, for chunk (at most 16 bytes) located at file offset off.
func appendHexdumpLine(dst []byte, off uint64, chunk []byte) []byte {
	const hexDigits = "0123456789abcdef"
	dst = append(dst, fmt.Sprintf("%08x: ", off)...)
	for i := 0; i < 16; i++ {
		if i < len(chunk) {
			b := chunk[i]
			dst = append(dst, hexDigits[b>>4], hexDigits[b&0x0f])
		} else {
			dst = append(dst, ' ', ' ')
		}
		if i%2 == 1 { // end of a 2-byte group
			dst = append(dst, ' ')
		}
	}
	dst = append(dst, ' ')
	for _, b := range chunk {
		if b < 0x20 || b >= 0x7f {
			b = '.'
		}
		dst = append(dst, b)
	}
	return append(dst, '\n')
}
