#include "pagealloc.h"

#include <fcntl.h>
#include <unistd.h>

#include <algorithm>
#include <cerrno>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <limits>
#include <ostream>

namespace dse::pagealloc
{

    namespace
    {

        // Error category for dse::pagealloc::Errc values: maps each code to a
        // short human-readable description.
        class Category final : public std::error_category
        {
        public:
            const char *name() const noexcept override
            {
                return "dse::pagealloc";
            }

            // Returns the description for `ev`; a value that matches none of the
            // listed Errc enumerators yields "unknown".
            std::string message(int ev) const override
            {
                const char *text = "unknown";
                switch (static_cast<Errc>(ev))
                {
                case Errc::Ok:
                    text = "ok";
                    break;
                case Errc::PayloadTooLarge:
                    text = "payload too large";
                    break;
                case Errc::PageMissing:
                    text = "page missing";
                    break;
                case Errc::BadMagic:
                    text = "bad page magic";
                    break;
                case Errc::BadPayloadLen:
                    text = "bad payload_len";
                    break;
                }
                return text;
            }
        };

        // Returns the single Category object shared by all callers. The
        // function-local static is constructed on first call (thread-safe
        // initialization since C++11).
        const Category &category_instance()
        {
            static const Category the_category{};
            return the_category;
        }

        // Stores `v` into p[0..7] in little-endian order (least-significant byte first).
        void put_u64_le(std::uint8_t *p, std::uint64_t v) noexcept
        {
            for (std::uint8_t *const stop = p + 8; p != stop; ++p)
            {
                *p = static_cast<std::uint8_t>(v & 0xFFu);
                v >>= 8;
            }
        }
        // Stores `v` into p[0..3] in little-endian order (least-significant byte first).
        void put_u32_le(std::uint8_t *p, std::uint32_t v) noexcept
        {
            p[0] = static_cast<std::uint8_t>(v & 0xFFu);
            p[1] = static_cast<std::uint8_t>((v >> 8) & 0xFFu);
            p[2] = static_cast<std::uint8_t>((v >> 16) & 0xFFu);
            p[3] = static_cast<std::uint8_t>((v >> 24) & 0xFFu);
        }
        // Stores `v` into p[0..1] in little-endian order (least-significant byte first).
        void put_u16_le(std::uint8_t *p, std::uint16_t v) noexcept
        {
            const unsigned wide = v;
            for (int k = 0; k < 2; ++k)
                p[k] = static_cast<std::uint8_t>(wide >> (8 * k));
        }
        // Reads a little-endian 64-bit value from p[0..7].
        std::uint64_t get_u64_le(const std::uint8_t *p) noexcept
        {
            std::uint64_t result = 0;
            // Walk from the most-significant byte (p[7]) down, shifting in one byte per step.
            for (int idx = 7; idx >= 0; --idx)
                result = (result << 8) | p[idx];
            return result;
        }
        // Reads a little-endian 32-bit value from p[0..3].
        std::uint32_t get_u32_le(const std::uint8_t *p) noexcept
        {
            return static_cast<std::uint32_t>(p[0]) |
                   (static_cast<std::uint32_t>(p[1]) << 8) |
                   (static_cast<std::uint32_t>(p[2]) << 16) |
                   (static_cast<std::uint32_t>(p[3]) << 24);
        }

        // Wraps the calling thread's current errno in a std::error_code using
        // std::system_category(). Must be called right after the failing call,
        // before anything else can overwrite errno.
        std::error_code errno_ec()
        {
            const int saved = errno;
            return std::error_code(saved, std::system_category());
        }

        // RAII owner of a POSIX file descriptor. `fd` holds the descriptor
        // (negative means "none") and is closed when the wrapper is destroyed.
        // Copying is disabled so a descriptor is never closed twice; moving
        // transfers ownership and leaves the source empty.
        struct Fd
        {
            int fd = -1;

            Fd() = default;
            // explicit: an arbitrary int must not silently become an owning wrapper.
            explicit Fd(int f) noexcept : fd(f) {}

            Fd(const Fd &) = delete;
            Fd &operator=(const Fd &) = delete;

            Fd(Fd &&other) noexcept : fd(other.fd) { other.fd = -1; }
            Fd &operator=(Fd &&other) noexcept
            {
                if (this != &other)
                {
                    reset();
                    fd = other.fd;
                    other.fd = -1;
                }
                return *this;
            }

            ~Fd() { reset(); }

            // Closes the held descriptor, if any, and marks the wrapper empty.
            // Errors from close() are ignored.
            void reset() noexcept
            {
                if (fd >= 0)
                    ::close(fd);
                fd = -1;
            }
        };

    } // namespace

    const std::error_category &pagealloc_category() noexcept
    {
        return category_instance();
    }

    // Serializes `payload` into the fixed-size page `out`:
    //   bytes 0..7            kPageMagic      (little-endian u64)
    //   bytes 8..9            kPageVersion    (little-endian u16)
    //   bytes 10..11          reserved, zero  (little-endian u16)
    //   bytes 12..15          payload length  (little-endian u32)
    //   bytes kHeaderLen..    payload bytes; the rest of the page is zero-filled
    // Returns true and clears `ec` on success. Returns false with
    // Errc::PayloadTooLarge, leaving `out` untouched, if
    // payload.size() > kMaxPayload.
    bool EncodePage(std::span<const std::uint8_t> payload,
                    std::array<std::uint8_t, kPageSize> &out,
                    std::error_code &ec) noexcept
    {
        if (payload.size() > kMaxPayload)
        {
            ec = make_error_code(Errc::PayloadTooLarge);
            return false;
        }
        out.fill(0);
        put_u64_le(out.data() + 0, kPageMagic);
        put_u16_le(out.data() + 8, kPageVersion);
        put_u16_le(out.data() + 10, 0);
        put_u32_le(out.data() + 12, static_cast<std::uint32_t>(payload.size()));
        // An empty span may have data() == nullptr, and memcpy with a null
        // source is undefined behavior even when the length is 0.
        if (!payload.empty())
            std::memcpy(out.data() + kHeaderLen, payload.data(), payload.size());
        ec = {};
        return true;
    }

    // Decodes a page laid out as EncodePage writes it. Checks the magic number
    // (bytes 0..7) and the payload length (bytes 12..15, must not exceed
    // kMaxPayload), then copies that many bytes starting at kHeaderLen into `out`.
    //
    // Returns true and clears `ec` on success. On failure returns false, sets
    // `ec` to Errc::BadMagic or Errc::BadPayloadLen, and leaves `out` untouched.
    // Bytes 8..11 (version / reserved) are not inspected.
    // NOTE(review): the copy stays inside `buf` only if
    // kHeaderLen + kMaxPayload <= kPageSize — confirm against pagealloc.h.
    bool DecodePage(std::uint64_t page_no,
                    const std::array<std::uint8_t, kPageSize> &buf,
                    std::vector<std::uint8_t> &out,
                    std::error_code &ec) noexcept
    {
        // page_no is accepted for signature parity with other language ports;
        // it takes no part in decoding.
        (void)page_no;

        const std::uint8_t *const base = buf.data();
        if (get_u64_le(base) != kPageMagic)
        {
            ec = make_error_code(Errc::BadMagic);
            return false;
        }

        const std::uint32_t payload_len = get_u32_le(base + 12);
        if (payload_len > kMaxPayload)
        {
            ec = make_error_code(Errc::BadPayloadLen);
            return false;
        }

        const std::uint8_t *const first = base + kHeaderLen;
        out.assign(first, first + payload_len);
        ec = {};
        return true;
    }

    // Encodes `payload` into a kPageSize-byte page and writes it at byte offset
    // page_no * kPageSize of the file at `path`. The file is created with mode
    // 0644 if it does not exist. After the write, FsyncData is called on the
    // descriptor.
    //
    // Returns true and clears `ec` on success. On failure returns false and sets `ec` to:
    //   - Errc::PayloadTooLarge (from EncodePage) if the payload does not fit,
    //   - EOVERFLOW if the page's byte range cannot be addressed with off_t,
    //   - EIO if pwrite makes no progress,
    //   - otherwise the errno reported by open / pwrite / FsyncData.
    // A failure after some bytes were written can leave the page partially
    // updated; no rollback is attempted.
    bool WritePage(const std::string &path,
                   std::uint64_t page_no,
                   std::span<const std::uint8_t> payload,
                   std::error_code &ec) noexcept
    {
        std::array<std::uint8_t, kPageSize> buf{};
        if (!EncodePage(payload, buf, ec))
            return false;

        // The whole page [page_no * kPageSize, (page_no + 1) * kPageSize) must be
        // addressable with signed off_t. Otherwise the offset arithmetic below
        // would overflow (UB). Check before open so no file is created.
        if (page_no >= static_cast<std::uint64_t>(std::numeric_limits<off_t>::max()) / kPageSize)
        {
            ec = std::error_code(EOVERFLOW, std::system_category());
            return false;
        }

        // O_CLOEXEC: don't leak the descriptor into a child exec'd by another thread.
        Fd f{::open(path.c_str(), O_RDWR | O_CREAT | O_CLOEXEC, 0644)};
        if (f.fd < 0)
        {
            ec = errno_ec();
            return false;
        }
        const off_t offset = static_cast<off_t>(page_no) * static_cast<off_t>(kPageSize);
        std::size_t written = 0;
        while (written < kPageSize)
        {
            const ssize_t n = ::pwrite(f.fd, buf.data() + written, kPageSize - written,
                                       offset + static_cast<off_t>(written));
            if (n < 0)
            {
                if (errno == EINTR)
                    continue; // interrupted before anything was written; retry
                ec = errno_ec();
                return false;
            }
            if (n == 0)
            {
                // No progress and no error: fail rather than loop forever.
                ec = std::error_code(EIO, std::system_category());
                return false;
            }
            written += static_cast<std::size_t>(n);
        }
        // NOTE(review): assumes FsyncData returns non-zero and sets errno on failure.
        if (FsyncData(f.fd) != 0)
        {
            ec = errno_ec();
            return false;
        }
        ec = {};
        return true;
    }

    // Reads page `page_no` (bytes [page_no * kPageSize, (page_no + 1) * kPageSize))
    // from the file at `path` and decodes its payload into `out` via DecodePage.
    //
    // Returns true and clears `ec` on success. On failure returns false and sets `ec` to:
    //   - EOVERFLOW if the page's byte range cannot be addressed with off_t,
    //   - the errno reported by open / pread,
    //   - Errc::PageMissing if the file ends before the page is complete,
    //   - DecodePage's error (Errc::BadMagic / Errc::BadPayloadLen) for a malformed page.
    // NOTE(review): a page that lies in a hole of a sparse file reads back as
    // zeros and is reported as BadMagic rather than PageMissing — confirm intended.
    bool ReadPage(const std::string &path,
                  std::uint64_t page_no,
                  std::vector<std::uint8_t> &out,
                  std::error_code &ec) noexcept
    {
        // The whole page must be addressable with signed off_t. Otherwise the
        // offset arithmetic below would overflow (UB).
        if (page_no >= static_cast<std::uint64_t>(std::numeric_limits<off_t>::max()) / kPageSize)
        {
            ec = std::error_code(EOVERFLOW, std::system_category());
            return false;
        }

        // O_CLOEXEC: don't leak the descriptor into a child exec'd by another thread.
        Fd f{::open(path.c_str(), O_RDONLY | O_CLOEXEC)};
        if (f.fd < 0)
        {
            ec = errno_ec();
            return false;
        }
        std::array<std::uint8_t, kPageSize> buf{};
        const off_t offset = static_cast<off_t>(page_no) * static_cast<off_t>(kPageSize);
        std::size_t got = 0;
        while (got < kPageSize)
        {
            const ssize_t n = ::pread(f.fd, buf.data() + got, kPageSize - got,
                                      offset + static_cast<off_t>(got));
            if (n < 0)
            {
                if (errno == EINTR)
                    continue; // interrupted before anything was read; retry
                ec = errno_ec();
                return false;
            }
            if (n == 0)
            { // EOF before a full page: the page was never (completely) written.
                ec = make_error_code(Errc::PageMissing);
                return false;
            }
            got += static_cast<std::size_t>(n);
        }
        return DecodePage(page_no, buf, out, ec);
    }

    // Writes a hex dump of the file at `path` to `os`, one line per 16 bytes:
    //   "<offset, at least 8 hex digits>: " + bytes as 2-byte hex groups (each
    //   group followed by a space, short lines padded) + " " + ASCII column,
    //   where bytes outside 0x20..0x7e are shown as '.'.
    //
    // Returns true and clears `ec` on success. Returns false with
    // errc::no_such_file_or_directory if the file cannot be opened (for any
    // reason; std::ifstream does not report why). Returns false with
    // errc::io_error if reading fails part-way; lines already written stay in `os`.
    // Errors on `os` itself are not checked.
    bool Hexdump(const std::string &path, std::ostream &os, std::error_code &ec) noexcept
    {
        std::ifstream in(path, std::ios::binary);
        if (!in)
        {
            ec = std::make_error_code(std::errc::no_such_file_or_directory);
            return false;
        }
        std::array<char, 4096> buf{};
        std::uint64_t offset = 0;
        while (in)
        {
            in.read(buf.data(), static_cast<std::streamsize>(buf.size()));
            const std::streamsize n = in.gcount();
            if (n <= 0)
                break;
            for (std::streamsize chunk = 0; chunk < n; chunk += 16)
            {
                const std::streamsize end = std::min<std::streamsize>(chunk + 16, n);
                // Up to 16 hex digits for a 64-bit offset + ": " + NUL = 19 bytes.
                // A 16-byte buffer silently truncated prefixes at offsets >= 4 GiB.
                char header[24];
                std::snprintf(header, sizeof(header), "%08llx: ",
                              static_cast<unsigned long long>(offset + static_cast<std::uint64_t>(chunk)));
                os << header;
                for (int i = 0; i < 16; ++i)
                {
                    if (chunk + i < end)
                    {
                        char hex[3];
                        std::snprintf(hex, sizeof(hex), "%02x",
                                      static_cast<unsigned char>(buf[chunk + i]));
                        os << hex;
                    }
                    else
                    {
                        os << "  "; // pad a short final line so the ASCII column aligns
                    }
                    if (i % 2 == 1)
                        os << ' ';
                }
                os << ' ';
                for (std::streamsize i = chunk; i < end; ++i)
                {
                    const unsigned char b = static_cast<unsigned char>(buf[i]);
                    os << static_cast<char>((b >= 0x20 && b < 0x7f) ? b : '.');
                }
                os << '\n';
            }
            offset += static_cast<std::uint64_t>(n);
        }
        // The loop ends on EOF (eofbit|failbit) as well as on a genuine read
        // error (badbit). Only badbit means the dump is incomplete.
        if (in.bad())
        {
            ec = std::make_error_code(std::errc::io_error);
            return false;
        }
        ec = {};
        return true;
    }

} // namespace dse::pagealloc
