// pagealloc CLI — Go reference for db-01-storage-primitives.
package main

import (
	"fmt"
	"os"
	"os/exec"
	"runtime"
	"sort"
	"strconv"
	"time"

	pagealloc "github.com/10xdev/dse/db01"
)

// usage is the top-level help text. main prints it to stderr when no
// subcommand is given or when the subcommand is not recognised.
const usage = `usage: pagealloc <subcommand> [args]

subcommands:
  write   <file> <page_no> <string>        write a single page
  read    <file> <page_no>                 read a single page, print payload
  hexdump <file>                           dump the file in xxd-compatible hex
  bench   <file> <pages> <iters>           random pread benchmark (warm + cold)
`

// main dispatches to the handler for the subcommand named in os.Args[1],
// passing it the remaining arguments. With no subcommand, or an unknown
// one, it prints the usage text to stderr and exits with status 2.
func main() {
	if len(os.Args) < 2 {
		fmt.Fprint(os.Stderr, usage)
		os.Exit(2)
	}

	name, rest := os.Args[1], os.Args[2:]

	// Lookup table instead of a switch: one entry per subcommand.
	handlers := map[string]func([]string){
		"write":   cmdWrite,
		"read":    cmdRead,
		"hexdump": cmdHexdump,
		"bench":   cmdBench,
	}

	run, known := handlers[name]
	if !known {
		fmt.Fprintf(os.Stderr, "unknown subcommand: %s\n", name)
		fmt.Fprint(os.Stderr, usage)
		os.Exit(2)
	}
	run(rest)
}

// cmdWrite implements the "write" subcommand.
//
// args must be exactly <file> <page_no> <string>. page_no is parsed as a
// base-10 unsigned 64-bit integer and the string is handed to
// pagealloc.WritePage as raw bytes. A wrong argument count or an
// unparsable page number exits with status 2; a WritePage failure exits
// with status 1.
func cmdWrite(args []string) {
	const wantArgs = 3
	if len(args) != wantArgs {
		fmt.Fprintln(os.Stderr, "usage: pagealloc write <file> <page_no> <string>")
		os.Exit(2)
	}
	file, rawPage, payload := args[0], args[1], args[2]

	page, parseErr := strconv.ParseUint(rawPage, 10, 64)
	if parseErr != nil {
		fmt.Fprintf(os.Stderr, "bad page_no: %v\n", parseErr)
		os.Exit(2)
	}

	writeErr := pagealloc.WritePage(file, page, []byte(payload))
	if writeErr == nil {
		return
	}
	fmt.Fprintf(os.Stderr, "write_page: %v\n", writeErr)
	os.Exit(1)
}

// cmdRead implements the "read" subcommand.
//
// args must be exactly <file> <page_no>. page_no is parsed as a base-10
// unsigned 64-bit integer. The bytes returned by pagealloc.ReadPage are
// written to stdout followed by a single newline.
//
// Exit status: 2 for a wrong argument count or an unparsable page number,
// 1 if ReadPage fails or if writing the result to stdout fails.
func cmdRead(args []string) {
	if len(args) != 2 {
		fmt.Fprintln(os.Stderr, "usage: pagealloc read <file> <page_no>")
		os.Exit(2)
	}
	pageNo, err := strconv.ParseUint(args[1], 10, 64)
	if err != nil {
		fmt.Fprintf(os.Stderr, "bad page_no: %v\n", err)
		os.Exit(2)
	}
	b, err := pagealloc.ReadPage(args[0], pageNo)
	if err != nil {
		fmt.Fprintf(os.Stderr, "read_page: %v\n", err)
		os.Exit(1)
	}
	// Check every write: a closed pipe or a full disk behind a redirect
	// must not be reported as success with truncated output.
	for _, chunk := range [][]byte{b, {'\n'}} {
		if _, err := os.Stdout.Write(chunk); err != nil {
			fmt.Fprintf(os.Stderr, "read_page: writing output: %v\n", err)
			os.Exit(1)
		}
	}
}

// cmdHexdump implements the "hexdump" subcommand.
//
// args must be exactly <file>. The dump is produced by pagealloc.Hexdump
// writing to stdout. A wrong argument count exits with status 2; a
// Hexdump failure exits with status 1.
func cmdHexdump(args []string) {
	if len(args) != 1 {
		fmt.Fprintln(os.Stderr, "usage: pagealloc hexdump <file>")
		os.Exit(2)
	}

	file := args[0]
	dumpErr := pagealloc.Hexdump(file, os.Stdout)
	if dumpErr == nil {
		return
	}
	fmt.Fprintf(os.Stderr, "hexdump: %v\n", dumpErr)
	os.Exit(1)
}

// cmdBench implements the "bench" subcommand.
//
// args must be exactly <file> <pages> <iters>, where pages and iters are
// base-10 unsigned 64-bit integers greater than zero. It first writes
// pages pages (payload "page-<n>") to the file, then runs two read passes
// of iters random reads each, reporting latency statistics for both. It
// calls dropPageCache between the passes, so the second pass is labelled
// COLD.
//
// Progress messages go to stderr; statistics go to stdout.
//
// Exit status: 2 for a wrong argument count or an invalid pages/iters
// value, 1 if writing a page fails.
func cmdBench(args []string) {
	if len(args) != 3 {
		fmt.Fprintln(os.Stderr, "usage: pagealloc bench <file> <pages> <iters>")
		os.Exit(2)
	}
	path := args[0]
	pages, err := strconv.ParseUint(args[1], 10, 64)
	if err != nil {
		fmt.Fprintf(os.Stderr, "bad pages: %v\n", err)
		os.Exit(2)
	}
	iters, err := strconv.ParseUint(args[2], 10, 64)
	if err != nil {
		fmt.Fprintf(os.Stderr, "bad iters: %v\n", err)
		os.Exit(2)
	}
	// Reject zero up front: runPass computes "% pages" (panics on a zero
	// divisor) and printStats indexes into the samples (panics when there
	// are none).
	if pages == 0 || iters == 0 {
		fmt.Fprintln(os.Stderr, "pages and iters must both be greater than zero")
		os.Exit(2)
	}

	fmt.Fprintf(os.Stderr, "preallocating %d pages = %d KiB\n", pages, pages*pagealloc.PageSize/1024)
	for p := uint64(0); p < pages; p++ {
		payload := fmt.Sprintf("page-%d", p)
		if err := pagealloc.WritePage(path, p, []byte(payload)); err != nil {
			fmt.Fprintf(os.Stderr, "preallocate: %v\n", err)
			os.Exit(1)
		}
	}

	fmt.Fprintln(os.Stderr, "warm-cache pass")
	warm := runPass(path, pages, iters)
	printStats("WARM", iters, warm)

	dropPageCache()

	fmt.Fprintln(os.Stderr, "cold-cache pass")
	cold := runPass(path, pages, iters)
	printStats("COLD", iters, cold)
}

// runPass performs iters reads of pseudo-randomly chosen pages in
// [0, pages) from path and returns the latency of each read, in the order
// the reads were issued.
//
// Page numbers come from an inline xorshift64*-style generator whose seed
// is derived from iters alone, so two passes with the same pages and iters
// visit the same sequence of pages. Only the pagealloc.ReadPage call is
// inside the timed region. A read error is reported on stderr and
// terminates the process with status 1.
//
// pages must be non-zero: it is used as a modulus.
func runPass(path string, pages, iters uint64) []time.Duration {
	const (
		seed       = uint64(0x9E3779B97F4A7C15)
		multiplier = uint64(0x2545F4914F6CDD1D)
	)

	samples := make([]time.Duration, 0, iters)
	rng := seed ^ iters
	for remaining := iters; remaining > 0; remaining-- {
		// Advance the generator, then scramble and reduce to a page index.
		rng ^= rng >> 12
		rng ^= rng << 25
		rng ^= rng >> 27
		pageNo := (rng * multiplier) % pages

		start := time.Now()
		_, readErr := pagealloc.ReadPage(path, pageNo)
		if readErr != nil {
			fmt.Fprintf(os.Stderr, "bench read: %v\n", readErr)
			os.Exit(1)
		}
		samples = append(samples, time.Since(start))
	}
	return samples
}

// printStats prints a latency and throughput summary for one benchmark
// pass to stdout.
//
// label is printed as "<label> cache:". iters is reported as the iteration
// count and is also used, multiplied by pagealloc.PageSize, as the number
// of bytes transferred when computing throughput. lats holds the
// per-read latencies; it is sorted in place, so the caller's slice is
// reordered.
//
// Percentiles are taken by index into the sorted samples (len/2,
// len*99/100, len*999/1000). Throughput is total bytes divided by the sum
// of all latencies, in MiB per second (printed with the "MB/s" unit).
//
// If lats is empty only the header, the iteration count and a
// "(no samples)" line are printed, since percentiles are undefined.
func printStats(label string, iters uint64, lats []time.Duration) {
	if len(lats) == 0 {
		// Indexing an empty slice below would panic.
		fmt.Printf("%s cache:\n", label)
		fmt.Printf("  iterations : %d\n", iters)
		fmt.Println("  (no samples)")
		return
	}

	sort.Slice(lats, func(i, j int) bool { return lats[i] < lats[j] })
	p50 := lats[len(lats)/2]
	p99 := lats[(len(lats)*99)/100]
	p999 := lats[(len(lats)*999)/1000]
	var total time.Duration
	for _, l := range lats {
		total += l
	}
	totalBytes := float64(iters) * float64(pagealloc.PageSize)
	mbps := totalBytes / total.Seconds() / (1024 * 1024)
	fmt.Printf("%s cache:\n", label)
	fmt.Printf("  iterations : %d\n", iters)
	fmt.Printf("  p50        : %d µs\n", p50.Microseconds())
	fmt.Printf("  p99        : %d µs\n", p99.Microseconds())
	fmt.Printf("  p99.9      : %d µs\n", p999.Microseconds())
	fmt.Printf("  throughput : %.0f MB/s\n", mbps)
}

// dropPageCache makes a best-effort attempt to evict the OS page cache so
// that the following benchmark pass reads from storage rather than memory.
//
// On Linux it runs "sync" and then writes 3 to /proc/sys/vm/drop_caches
// via sh; on macOS it runs "sudo -n purge". On any other OS it does
// nothing. In every case a one-line outcome message is written to stderr;
// failure is reported but never fatal.
func dropPageCache() {
	goos := runtime.GOOS
	if goos != "linux" && goos != "darwin" {
		fmt.Fprintln(os.Stderr, "page cache drop not implemented on this OS")
		return
	}

	var (
		drop            *exec.Cmd
		okMsg, failMsg string
	)
	if goos == "linux" {
		// Run sync first; its result is deliberately ignored because the
		// drop below is attempted regardless.
		_ = exec.Command("sync").Run()
		drop = exec.Command("sh", "-c", "echo 3 > /proc/sys/vm/drop_caches")
		okMsg = "dropped page cache (Linux)"
		failMsg = "could not drop page cache (run as root for true cold pass)"
	} else {
		// -n makes sudo fail instead of prompting for a password.
		drop = exec.Command("sudo", "-n", "purge")
		okMsg = "dropped page cache (macOS purge)"
		failMsg = "could not drop page cache (run as admin: `sudo purge`)"
	}

	outcome := okMsg
	if runErr := drop.Run(); runErr != nil {
		outcome = failMsg
	}
	fmt.Fprintln(os.Stderr, outcome)
}
