//
// Syd: rock-solid application kernel
// src/rng.rs: OS Random Number Generator (RNG) interface
//
// Copyright (c) 2023, 2024, 2025 Ali Polatel <alip@chesswob.org>
//
// SPDX-License-Identifier: GPL-3.0

//! Set of functions to manage the OS Random Number Generator (RNG)

use std::{
    ops::RangeInclusive,
    os::fd::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd},
};

use libc::{c_int, dup3, GRND_RANDOM};
use memchr::memchr;
use nix::{
    errno::Errno,
    fcntl::{OFlag, ResolveFlag},
    sys::resource::{getrlimit, Resource},
    unistd::{close, UnlinkatFlags},
};

use crate::{
    cookie::safe_unlinkat,
    fs::is_active_fd,
    log::{now, Tm},
    lookup::safe_open,
    path::{XPathBuf, PATH_MAX},
    retry::retry_on_eintr,
};

/// RAII guard that disables pthread cancellation for the current thread
/// and restores the previous state on drop. Uses pthread_setcancelstate(3).
///
/// The wrapped `c_int` is the cancel state reported by
/// pthread_setcancelstate(3) when the guard was acquired; that same
/// value is passed back to pthread_setcancelstate(3) when the guard is
/// dropped.
#[must_use = "hold the guard to keep cancellation disabled"]
pub struct CancelGuard(c_int);

// Cancel-state values accepted by pthread_setcancelstate(3).
// The ENABLE value is not referenced in this file, hence the leading
// underscore; it is kept next to DISABLE for reference.
// NOTE(review): 0/1 are hard-coded rather than taken from libc headers;
// confirm they match every libc this crate is built against.
const _PTHREAD_CANCEL_ENABLE: c_int = 0;
const PTHREAD_CANCEL_DISABLE: c_int = 1;

// Libc crate does not define this symbol explicitly yet.
extern "C" {
    // Sets the calling thread's cancel state to `state` and, when
    // `oldstate` is non-NULL, stores the previous state through it.
    // Returns 0 on success or an error number on failure.
    fn pthread_setcancelstate(state: c_int, oldstate: *mut c_int) -> c_int;
}

impl CancelGuard {
    /// Turn off pthread cancellation for the calling thread.
    ///
    /// On success the returned guard remembers the cancel state that was
    /// active beforehand and reinstates it when dropped. On failure the
    /// error number returned by pthread_setcancelstate(3) is converted
    /// into an `Errno`.
    pub fn acquire() -> Result<Self, Errno> {
        let mut prev: c_int = 0;

        // SAFETY: pthread_setcancelstate(3) only affects the calling thread.
        // - PTHREAD_CANCEL_DISABLE is a valid state constant.
        // - `prev` is a live, writable c_int receiving the previous state.
        // - No Rust value is moved or aliased by the call.
        let rc = unsafe { pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &raw mut prev) };

        // The error code is the return value itself; errno is not involved.
        match rc {
            0 => Ok(Self(prev)),
            errnum => Err(Errno::from_raw(errnum)),
        }
    }
}

impl Drop for CancelGuard {
    /// Reinstate the cancel state captured when the guard was acquired.
    fn drop(&mut self) {
        let saved = self.0;

        // SAFETY: `saved` is the state previously reported for this
        // thread by pthread_setcancelstate(3). Passing NULL as the
        // second argument is permitted when the old value is not needed.
        // The return value is deliberately ignored: drop cannot fail.
        let _ = unsafe { pthread_setcancelstate(saved, std::ptr::null_mut()) };
    }
}

/// Public trait for unsigned integers that support uniform sampling without widening.
///
/// It exposes exactly the constants and checked arithmetic that
/// [`randint`] needs, so the sampling code can stay generic over the
/// integer width and never has to cast to a wider type.
pub trait RandUint: Copy + Ord {
    /// Additive zero.
    const ZERO: Self;
    /// Additive one.
    const ONE: Self;
    /// Maximum value.
    const MAX: Self;

    /// Draw a uniformly random value of this type using the OS RNG for exactly this width.
    ///
    /// Returns the `Errno` reported by the underlying RNG call on failure.
    fn rand_from_os() -> Result<Self, Errno>;

    /// Checked add returning None on overflow.
    fn checked_add(self, rhs: Self) -> Option<Self>;
    /// Checked sub returning None on underflow.
    fn checked_sub(self, rhs: Self) -> Option<Self>;
    /// Checked mul returning None on overflow.
    fn checked_mul(self, rhs: Self) -> Option<Self>;

    /// Euclidean division returning None if rhs is zero.
    fn div_euclid_opt(self, rhs: Self) -> Option<Self>;
}

// Implements `RandUint` for each listed primitive unsigned integer type
// by forwarding to the type's inherent constants and checked arithmetic.
macro_rules! impl_rand_uint {
    ($($t:ty),* $(,)?) => {$(
        impl RandUint for $t {
            const ZERO: Self = 0;
            const ONE: Self = 1;
            const MAX: Self = <$t>::MAX;

            #[inline]
            fn rand_from_os() -> Result<Self, Errno> {
                // Fetch as many random bytes as the type is wide and
                // reinterpret them in native byte order.
                let mut raw = [0u8; size_of::<$t>()];
                fillrandom(&mut raw)?;
                Ok(<$t>::from_ne_bytes(raw))
            }

            #[inline]
            fn checked_add(self, rhs: Self) -> Option<Self> {
                <$t>::checked_add(self, rhs)
            }

            #[inline]
            fn checked_sub(self, rhs: Self) -> Option<Self> {
                <$t>::checked_sub(self, rhs)
            }

            #[inline]
            fn checked_mul(self, rhs: Self) -> Option<Self> {
                <$t>::checked_mul(self, rhs)
            }

            #[inline]
            fn div_euclid_opt(self, rhs: Self) -> Option<Self> {
                // Division by zero is reported as None instead of panicking.
                match rhs {
                    0 => None,
                    _ => Some(self.div_euclid(rhs)),
                }
            }
        }
    )*};
}
impl_rand_uint!(u8, u16, u32, u64, u128, usize);

/// Pick a uniformly distributed unsigned integer from the inclusive
/// `range`, drawing from the OS RNG and rejecting draws that would
/// introduce modulo bias.
///
/// # Errors
///
/// - `EINVAL` if the range is inverted or contains a single value.
/// - `EOVERFLOW` if any intermediate checked operation fails.
/// - Any error reported by the OS RNG.
pub fn randint<T>(range: RangeInclusive<T>) -> Result<T, Errno>
where
    T: RandUint,
{
    let (lo, hi) = range.into_inner();

    // Inverted and one-point ranges are treated as invalid input.
    if hi <= lo {
        return Err(Errno::EINVAL);
    }

    // The whole domain of T needs no reduction: raw OS bytes are already uniform.
    if lo == T::ZERO && hi == T::MAX {
        return T::rand_from_os();
    }

    // width = (hi - lo) + 1, i.e. the number of values in the range.
    let width = hi
        .checked_sub(lo)
        .and_then(|delta| delta.checked_add(T::ONE))
        .ok_or(Errno::EOVERFLOW)?;

    // limit = floor(MAX / width) * width; draws at or above it are discarded
    // so every residue class is equally likely.
    let limit = T::MAX
        .div_euclid_opt(width)
        .and_then(|k| k.checked_mul(width))
        .ok_or(Errno::EOVERFLOW)?;

    loop {
        let draw = T::rand_from_os()?;
        if draw >= limit {
            // Rejected: outside the unbiased window, draw again.
            continue;
        }

        // offset = draw - floor(draw / width) * width, computed without
        // using a remainder operator.
        let offset = draw
            .div_euclid_opt(width)
            .and_then(|q| q.checked_mul(width))
            .and_then(|base| draw.checked_sub(base))
            .ok_or(Errno::EOVERFLOW)?;

        return lo.checked_add(offset).ok_or(Errno::EOVERFLOW);
    }
}

/// Pick a random port number in `1025..=65535` using the OS RNG.
#[inline]
pub fn randport() -> Result<u16, Errno> {
    let ports = 1025u16..=u16::MAX;
    randint(ports)
}

/// Allocate a vector of `size` bytes and fill it from the OS random
/// number generator.
///
/// # Errors
///
/// - `EINVAL` if `size` is zero.
/// - `ENOMEM` if the buffer cannot be allocated.
/// - Any error returned by [`fillrandom`].
pub fn getrandom(size: usize) -> Result<Vec<u8>, Errno> {
    // A zero-length request is almost always a caller bug, reject it.
    if size == 0 {
        return Err(Errno::EINVAL);
    }

    // Reserve fallibly so allocation failure becomes ENOMEM, not an abort.
    let mut bytes = Vec::new();
    bytes.try_reserve(size).or(Err(Errno::ENOMEM))?;
    bytes.resize(size, 0);

    fillrandom(&mut bytes)?;
    Ok(bytes)
}

/// Overwrite every byte of `buf` with data from the OS random number
/// generator, looping over short reads until the buffer is full.
///
/// # Errors
///
/// - `EINVAL` if `buf` is empty or the syscall result cannot be
///   represented as a length.
/// - `EOVERFLOW` if the running byte count overflows.
/// - Any error from acquiring the cancellation guard or from getrandom(2).
pub fn fillrandom(buf: &mut [u8]) -> Result<(), Errno> {
    // An empty buffer is almost always a caller bug, reject it.
    if buf.is_empty() {
        return Err(Errno::EINVAL);
    }
    let total = buf.len();

    // Keep pthread cancellation disabled while we talk to the kernel.
    // The previous state comes back when the guard is dropped.
    let guard = CancelGuard::acquire()?;

    let mut filled = 0;
    while filled < total {
        // Unfilled tail of the buffer and its length.
        let rest = &mut buf[filled..];
        let want = rest.len();
        let dst = rest.as_mut_ptr().cast();

        let got = retry_on_eintr(|| {
            // SAFETY: In libc we trust.
            Errno::result(unsafe { libc::getrandom(dst, want, GRND_RANDOM) })
        })?;
        let got: usize = got.try_into().or(Err(Errno::EINVAL))?;

        filled = filled.checked_add(got).ok_or(Errno::EOVERFLOW)?;
    }

    // Critical section is over.
    drop(guard);

    Ok(())
}

/// Duplicate `oldfd` onto a randomly chosen, currently unused file
/// descriptor number.
///
/// Valid flags:
/// - O_EXCL: closes oldfd after successful duplication.
/// - All other flags are passed to dup3(2), ie O_CLOEXEC.
///
/// # Errors
///
/// - `EMFILE` if the usable descriptor range is empty or dup3(2) reports it.
/// - `EBADF` if no candidate could be duplicated onto.
/// - Any error from getrlimit(2) or the OS RNG.
pub fn duprand(oldfd: RawFd, mut flags: OFlag) -> Result<RawFd, Errno> {
    // O_EXCL is interpreted here and must not reach dup3(2).
    let close_old = flags.contains(OFlag::O_EXCL);
    flags.remove(OFlag::O_EXCL);

    // Candidate descriptors span [7, min(RLIMIT_NOFILE - 1, 0x10000)].
    let range_start = 7u64;
    let (nofile, _) = getrlimit(Resource::RLIMIT_NOFILE)?;
    #[expect(clippy::unnecessary_cast)]
    let range_end = nofile.saturating_sub(1) as u64;

    // SAFETY: Very large limits tend to make dup3(2) fail with ENOMEM,
    // so clamp the upper bound to a sane maximum.
    let range_end = range_end.min(0x10000);
    if range_end <= range_start {
        return Err(Errno::EMFILE);
    }
    let range = range_start..=range_end;

    // SAFETY: A random descriptor number makes the fd harder for an
    // attacker to locate. Try at most as many times as there are slots.
    for _ in range.clone() {
        #[expect(clippy::cast_possible_truncation)]
        let candidate = randint(range.clone())? as RawFd;

        // SAFETY: fd only used after validation.
        let candidate = unsafe { BorrowedFd::borrow_raw(candidate) };

        // Skip slots that are already in use. The check can race with
        // other threads, which is acceptable because this only runs
        // for descriptors set up at startup.
        if is_active_fd(candidate) {
            continue;
        }

        let dup = retry_on_eintr(|| {
            // SAFETY: In libc we trust.
            Errno::result(unsafe { dup3(oldfd, candidate.as_raw_fd(), flags.bits()) })
        });
        match dup {
            Err(Errno::EMFILE) => return Err(Errno::EMFILE),
            // Any other failure: try another slot.
            Err(_) => continue,
            Ok(_) => {}
        }

        if close_old {
            let _ = close(oldfd);
        }
        return Ok(candidate.as_raw_fd());
    }

    Err(Errno::EBADF)
}

/// Create a unique temporary file in `dirfd` relative to `prefix`
/// unlink the file and return its file descriptor. Unlike libc's
/// mkstemp(3) function the template here does not have to end with any
/// number of `X` characters. The function appends an implementation
/// defined number of random characters after `prefix`. `prefix` must
/// not contain the `/` character, and `prefix` together with the random
/// suffix must not be longer than `PATH_MAX` bytes. It is OK for prefix
/// to be empty.
/// If `dirfd` supports the `O_TMPFILE` operation, an unnamed temporary
/// file is created instead with `O_TMPFILE|O_EXCL`; in that case
/// `prefix` is not used and not validated.
///
/// # Errors
///
/// - `EINVAL` if `prefix` contains `/`.
/// - `ENAMETOOLONG` if `prefix` plus the random suffix exceeds `PATH_MAX`.
/// - `EEXIST` if every generated name collided with an existing file.
/// - `ENOMEM` if the name buffer cannot be allocated.
/// - Any other error from opening or unlinking the file, or from the OS RNG.
pub fn mkstempat<Fd: AsFd>(dirfd: Fd, prefix: &[u8]) -> Result<OwnedFd, Errno> {
    const MAX_TCOUNT: usize = 8;
    const SUFFIX_LEN: usize = 128;
    const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";

    // Step 1: Attempt to use O_TMPFILE|O_EXCL which is safer.
    let mut flags = OFlag::O_TMPFILE | OFlag::O_EXCL | OFlag::O_RDWR;
    match safe_open(&dirfd, c".", flags, ResolveFlag::empty()) {
        Ok(fd) => return Ok(fd),
        // These errors indicate O_TMPFILE is not usable here; fall back.
        Err(Errno::EISDIR | Errno::ENOENT | Errno::EOPNOTSUPP) => {}
        Err(errno) => return Err(errno),
    }

    // Step 2: Fallback to random name generation.
    // O_EXCL stays set so O_CREAT fails with EEXIST on a name collision.
    flags.remove(OFlag::O_TMPFILE);
    flags.insert(OFlag::O_CREAT);
    if memchr(b'/', prefix).is_some() {
        return Err(Errno::EINVAL);
    }

    // The generated name is `prefix` followed by SUFFIX_LEN characters,
    // so it is the combined length that has to fit within PATH_MAX.
    let size = prefix
        .len()
        .checked_add(SUFFIX_LEN)
        .ok_or(Errno::EOVERFLOW)?;
    if size > PATH_MAX {
        return Err(Errno::ENAMETOOLONG);
    }

    let mut rng_data = [0u8; SUFFIX_LEN];
    for _ in 0..MAX_TCOUNT {
        // Fill with random bytes.
        fillrandom(&mut rng_data)?;

        // Map bytes to characters.
        // Reducing a byte modulo the charset length is slightly
        // non-uniform; with SUFFIX_LEN characters this is immaterial
        // for collision avoidance.
        let mut base = XPathBuf::new();
        base.try_reserve(size).or(Err(Errno::ENOMEM))?;
        base.append_bytes(prefix);
        for &b in &rng_data {
            let idx = (b as usize)
                .checked_rem(CHARSET.len())
                .ok_or(Errno::EOVERFLOW)?;
            let chr = CHARSET.get(idx).copied().ok_or(Errno::EOVERFLOW)?;
            base.append_byte(chr);
        }

        match safe_open(&dirfd, &base, flags, ResolveFlag::empty()) {
            Ok(fd) => {
                // Remove the name right away; the open descriptor
                // is all the caller gets.
                safe_unlinkat(dirfd, &base, UnlinkatFlags::NoRemoveDir)?;
                return Ok(fd);
            }
            Err(Errno::EEXIST) => {
                // Try again with a new random sequence.
                continue;
            }
            Err(errno) => return Err(errno),
        }
    }

    // Too many collisions.
    Err(Errno::EEXIST)
}

/// Build a plausible-looking Linux kernel version string with a random
/// build number, a random preemption flavour and a build timestamp
/// picked at random from roughly the past year.
///
/// # Errors
///
/// Returns `EOVERFLOW` if the chosen timestamp does not fit in an `i64`,
/// and otherwise any error from the OS RNG or the time conversion.
pub fn rand_version() -> Result<String, Errno> {
    // Width of the timestamp window in seconds (~one year).
    const TS_WINDOW: u64 = 366 * 86_400;
    const WKDAYS: &[&str] = &["Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"];
    const MONTHS: &[&str] = &[
        "Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec",
    ];
    const VERMAGICS: &[&str] = &[
        "SMP",
        "SMP PREEMPT",
        "SMP PREEMPT_DYNAMIC",
        "SMP PREEMPT_RT",
    ];

    // Step back from the current time by a random number of seconds.
    let current = now();
    let offset: u64 = randint(0..=TS_WINDOW)?;
    let shifted = current.saturating_sub(offset);
    let target: i64 = shifted.try_into().or(Err(Errno::EOVERFLOW))?;

    // Split the instant into calendar fields.
    let tm = Tm::try_from(target)?;

    // Random build number and preemption variant.
    let build_no = randint(1u8..=64)?;
    #[expect(clippy::arithmetic_side_effects)]
    let vermagic = VERMAGICS[randint(0usize..=(VERMAGICS.len() - 1))?];

    // Names for the month (1-based in `tm`) and the weekday.
    #[expect(clippy::arithmetic_side_effects)]
    #[expect(clippy::cast_sign_loss)]
    let mon = MONTHS[(tm.month() - 1) as usize];
    #[expect(clippy::cast_sign_loss)]
    let wday = WKDAYS[tm.weekday() as usize];

    // Remaining date and time fields.
    let (year, mday) = (tm.year(), tm.day());
    let (hh, mm, ss) = (tm.hour(), tm.minute(), tm.second());

    Ok(format!(
        "#{build_no} {vermagic} {wday} {mon} {mday:>2} {hh:02}:{mm:02}:{ss:02} UTC {year}",
    ))
}

#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use super::*;

    // Sanity checks for the byte-oriented API.

    #[test]
    fn test_fillrandom() {
        let mut empty: [u8; 0] = [];
        assert_eq!(fillrandom(&mut empty), Err(Errno::EINVAL));

        let mut buf = [0u8; 257];
        assert_eq!(fillrandom(&mut buf), Ok(()));
    }

    #[test]
    fn test_getrandom() {
        assert_eq!(getrandom(0), Err(Errno::EINVAL));

        let result = getrandom(257);
        assert!(result.is_ok(), "result:{result:?}");
    }

    // Helpers shared by the randint tests.

    /// Draw one value from `[lo, hi]`, panicking with context on error.
    fn draw<T: RandUint + Debug>(lo: T, hi: T) -> T {
        randint::<T>(lo..=hi)
            .unwrap_or_else(|e| panic!("randint failed for [{:?},{:?}] -> {:?}", lo, hi, e))
    }

    /// Draw `n` values from `[lo, hi]`.
    fn sample<T: RandUint + Debug>(lo: T, hi: T, n: usize) -> Vec<T> {
        let mut drawn = Vec::with_capacity(n);
        for _ in 0..n {
            drawn.push(draw::<T>(lo, hi));
        }
        drawn
    }

    /// True when no element of `xs` falls outside `[lo, hi]`.
    fn all_in_range<T: RandUint + Debug>(xs: &[T], lo: T, hi: T) -> bool {
        !xs.iter().any(|&v| v < lo || v > hi)
    }

    // Generates one test per entry asserting that randint() rejects
    // the given range with EINVAL.
    macro_rules! rejects_with_einval {
        ($($name:ident => $t:ty, $range:expr;)+) => {$(
            #[test]
            fn $name() {
                assert!(matches!(randint::<$t>($range), Err(Errno::EINVAL)));
            }
        )+};
    }

    // Generates one test per entry that samples `$n` values from
    // `[$lo, MAX]` and asserts that all of them stay within bounds.
    macro_rules! stays_in_bounds {
        ($($name:ident => $t:ty, $lo:expr, $n:expr;)+) => {$(
            #[test]
            fn $name() {
                let lo: $t = $lo;
                let xs = sample::<$t>(lo, <$t>::MAX, $n);
                assert!(all_in_range(&xs, lo, <$t>::MAX));
            }
        )+};
    }

    // Inverted ranges are invalid.
    rejects_with_einval! {
        test_randint_invalid_u8 => u8, 200..=100;
        test_randint_invalid_u16 => u16, 5000..=4999;
        test_randint_invalid_u32 => u32, 42..=41;
        test_randint_invalid_u64 => u64, 999..=998;
        test_randint_invalid_u128 => u128, 500..=499;
        test_randint_invalid_usize => usize, 100..=99;
    }

    // One-point ranges are invalid as well.
    rejects_with_einval! {
        test_randint_onepoint_u8 => u8, 77..=77;
        test_randint_onepoint_u16 => u16, 31337..=31337;
        test_randint_onepoint_u32 => u32, 1_000_000..=1_000_000;
        test_randint_onepoint_u64 => u64, 123456789..=123456789;
        test_randint_onepoint_u128 => u128, 999..=999;
        test_randint_onepoint_usize => usize, 4242..=4242;
    }

    // Full-domain sampling (raw OS bytes path).
    stays_in_bounds! {
        test_randint_fulldomain_u8_inbounds => u8, u8::MIN, 4096;
        test_randint_fulldomain_u16_inbounds => u16, u16::MIN, 2048;
        test_randint_fulldomain_u32_inbounds => u32, u32::MIN, 2048;
        test_randint_fulldomain_u64_inbounds => u64, u64::MIN, 1024;
        test_randint_fulldomain_u128_inbounds => u128, u128::MIN, 256;
        test_randint_fulldomain_usize_inbounds => usize, usize::MIN, 1024;
    }

    // Narrow ranges ending at the type maximum (rejection sampling path).
    stays_in_bounds! {
        test_randint_u8_nearmax_inbounds => u8, u8::MAX.saturating_sub(15), 2000;
        test_randint_u16_nearmax_inbounds => u16, u16::MAX.saturating_sub(1023), 4000;
        test_randint_u32_nearmax_inbounds => u32, u32::MAX.saturating_sub(1000), 3000;
        test_randint_u64_nearmax_inbounds => u64, u64::MAX.saturating_sub(1000), 3000;
        test_randint_u128_nearmax_inbounds => u128, u128::MAX.saturating_sub(1000), 2000;
        test_randint_usize_nearmax_inbounds => usize, usize::MAX.saturating_sub(1000), 3000;
    }
}
