//
// Syd: rock-solid application kernel
// src/mask.rs: Utilities to mask sensitive information in proc files
//
// Copyright (c) 2025 Ali Polatel <alip@chesswob.org>
// SPDX-License-Identifier: GPL-3.0

// SAFETY: This module has been liberated from unsafe code!
#![forbid(unsafe_code)]

use std::{mem::take, os::fd::AsFd};

use memchr::{memchr, memmem};
use nix::{errno::Errno, unistd::read};

use crate::{io::write_all, retry::retry_on_eintr};

//
// Data types
//
/// A single textual substitution applied to a header value.
///
/// Within one line, each entry of a patch table is applied at most once;
/// which entries were already applied is tracked in a [`PatchMask`].
#[derive(Clone, Copy, Debug)]
struct Patch {
    needle: &'static [u8], // bytes to search for
    repl: &'static [u8],   // bytes written in place of `needle`
}
/// Applied-patch bitmask: bit `i` stands for entry `i` of a patch table,
/// so a table can hold at most 8 patches.
type PatchMask = u8;
/// Outcome of one patch step: `(bytes consumed from the value, new mask)`,
/// or `None` when no remaining patch matches.
type PatchStep = Option<(usize, PatchMask)>;

//
// Field Prefixes
//
// Status headers whose numeric value gets zeroed; matched at line start.
const TRACERPID: &[u8] = "TracerPid:".as_bytes();
const NONEWPRIVS: &[u8] = "NoNewPrivs:".as_bytes();
const SECCOMP: &[u8] = "Seccomp:".as_bytes();
const SECCOMP_FILTERS: &[u8] = "Seccomp_filters:".as_bytes();

//
// Speculation Prefixes
//
// Status headers whose value gets rewritten by a patch table.
const SPEC_SSB: &[u8] = "Speculation_Store_Bypass:".as_bytes();
const SPEC_SIB: &[u8] = "SpeculationIndirectBranch:".as_bytes();

//
// Patch Tables
//
// SSB: Normalize to least-safe wording.
// Patches are applied leftmost-first, each at most once per line.
const SPEC_SSB_PATCHES: &[Patch] = &[
    Patch {
        needle: b"unknown",
        repl: b"vulnerable",
    },
    Patch {
        needle: b"unsupported",
        repl: b"vulnerable",
    },
    Patch {
        needle: b"thread ",
        repl: b"",
    },
    Patch {
        needle: b"force ",
        repl: b"",
    },
    Patch {
        needle: b"mitigated",
        repl: b"vulnerable",
    },
];
// Each entry needs its own bit in the u8 applied-mask (PatchMask).
const _: () = assert!(
    SPEC_SSB_PATCHES.len() <= u8::BITS as usize,
    "SPEC_SSB_PATCHES exceeds the 8-entry PatchMask limit"
);

// SIB: Make always+enabled.
// Patches are applied leftmost-first, each at most once per line.
const SPEC_SIB_PATCHES: &[Patch] = &[
    Patch {
        needle: b"unknown",
        repl: b"always enabled",
    },
    Patch {
        needle: b"unsupported",
        repl: b"always enabled",
    },
    Patch {
        needle: b"conditional",
        repl: b"always",
    },
    Patch {
        needle: b"force ",
        repl: b"",
    },
    Patch {
        needle: b"disabled",
        repl: b"enabled",
    },
];
// Each entry needs its own bit in the u8 applied-mask (PatchMask).
const _: () = assert!(
    SPEC_SIB_PATCHES.len() <= u8::BITS as usize,
    "SPEC_SIB_PATCHES exceeds the 8-entry PatchMask limit"
);

//
// Prefix Flags (u8):
// Headers do not repeat in /proc/*/status.
//
// One bit per header; a bit is set in the masker's `prefix_mask` once that
// header has been handled, so later lines with the same prefix pass through.
const PF_TRACERPID: u8 = 0b0000_0001;
const PF_NONEWPRIVS: u8 = 0b0000_0010;
const PF_SECCOMP: u8 = 0b0000_0100;
const PF_SECCOMP_FILTERS: u8 = 0b0000_1000;
const PF_SPEC_SSB: u8 = 0b0001_0000;
const PF_SPEC_SIB: u8 = 0b0010_0000;

//
// proc_pid_status(5) Masker
//
struct ProcPidStatusMasker {
    inbuf: Vec<u8>,  // whole-file input
    outbuf: Vec<u8>, // whole-file output
    prefix_mask: u8, // which headers we have already matched
}

impl ProcPidStatusMasker {
    const INBUF_CAP: usize = 2048;
    const OUTBUF_CAP: usize = 2048;
    const GROW_STEP: usize = 128;

    fn new() -> Result<Self, Errno> {
        let mut inbuf = Vec::new();
        inbuf
            .try_reserve(Self::INBUF_CAP)
            .map_err(|_| Errno::ENOMEM)?;
        let mut outbuf = Vec::new();
        outbuf
            .try_reserve(Self::OUTBUF_CAP)
            .map_err(|_| Errno::ENOMEM)?;
        Ok(Self {
            inbuf,
            outbuf,
            prefix_mask: 0,
        })
    }

    fn obuf_write(&mut self, data: &[u8]) -> Result<(), Errno> {
        if data.is_empty() {
            return Ok(());
        }
        self.outbuf
            .try_reserve(data.len())
            .map_err(|_| Errno::ENOMEM)?;
        self.outbuf.extend_from_slice(data);
        Ok(())
    }

    #[inline]
    fn flush_all<Fd: AsFd>(&mut self, out: Fd) -> Result<(), Errno> {
        // single write(2) syscall for the whole file.
        write_all(&out, &self.outbuf)
    }

    // Check zero-able fields.
    fn try_emit_zero_field(&mut self, line: &[u8]) -> Result<bool, Errno> {
        if (self.prefix_mask & PF_TRACERPID) == 0 && line.starts_with(TRACERPID) {
            self.prefix_mask |= PF_TRACERPID;
            self.emit_zero_field(line, TRACERPID)?;
            return Ok(true);
        }
        if (self.prefix_mask & PF_NONEWPRIVS) == 0 && line.starts_with(NONEWPRIVS) {
            self.prefix_mask |= PF_NONEWPRIVS;
            self.emit_zero_field(line, NONEWPRIVS)?;
            return Ok(true);
        }
        if (self.prefix_mask & PF_SECCOMP_FILTERS) == 0 && line.starts_with(SECCOMP_FILTERS) {
            self.prefix_mask |= PF_SECCOMP_FILTERS;
            self.emit_zero_field(line, SECCOMP_FILTERS)?;
            return Ok(true);
        }
        if (self.prefix_mask & PF_SECCOMP) == 0 && line.starts_with(SECCOMP) {
            self.prefix_mask |= PF_SECCOMP;
            self.emit_zero_field(line, SECCOMP)?;
            return Ok(true);
        }
        Ok(false)
    }

    // Check speculation groups.
    fn try_emit_patch_group(&mut self, line: &[u8]) -> Result<bool, Errno> {
        if (self.prefix_mask & PF_SPEC_SSB) == 0 && line.starts_with(SPEC_SSB) {
            self.prefix_mask |= PF_SPEC_SSB;
            let (head, value) = line.split_at(SPEC_SSB.len());
            self.obuf_write(head)?;
            return self.emit_patch_group_value(value, SPEC_SSB_PATCHES);
        }
        if (self.prefix_mask & PF_SPEC_SIB) == 0 && line.starts_with(SPEC_SIB) {
            self.prefix_mask |= PF_SPEC_SIB;
            let (head, value) = line.split_at(SPEC_SIB.len());
            self.obuf_write(head)?;
            return self.emit_patch_group_value(value, SPEC_SIB_PATCHES);
        }
        Ok(false)
    }

    // Zero-out numeric field while preserving whitespace after colon.
    fn emit_zero_field(&mut self, line: &[u8], field: &[u8]) -> Result<(), Errno> {
        let mut i = field.len();

        // skip whitespace after colon
        while i < line.len() {
            let b = line[i];
            if b == b' ' || b == b'\t' {
                i = i.checked_add(1).ok_or(Errno::EOVERFLOW)?;
            } else {
                break;
            }
        }

        let start = i;
        while i < line.len() && line[i].is_ascii_digit() {
            i = i.checked_add(1).ok_or(Errno::EOVERFLOW)?;
        }
        let end = i;

        // already "0" or empty -> passthrough
        let digits_len = end.checked_sub(start).ok_or(Errno::EOVERFLOW)?;
        if digits_len == 0 || (digits_len == 1 && line[start] == b'0') {
            self.obuf_write(line)?;
            return Ok(());
        }

        self.obuf_write(&line[..start])?;
        self.obuf_write(b"0\n")?;
        Ok(())
    }

    // One patch step:
    // - Scan patches not yet applied on this line.
    // - Pick earliest match (leftmost) in `value`.
    // - Write `left` then `repl`.
    // - Return `(bytes_consumed_from_value, new_applied_mask)`.
    fn apply_patch_step(
        &mut self,
        value: &[u8],
        patches: &[Patch],
        applied: PatchMask,
    ) -> Result<PatchStep, Errno> {
        if value.is_empty() || patches.is_empty() {
            return Ok(None);
        }

        let mut best_pos: Option<usize> = None;
        let mut best_idx: usize = 0;

        #[expect(clippy::cast_possible_truncation)]
        for (idx, p) in patches.iter().enumerate() {
            if ((applied >> (idx as u32)) & 1) != 0 {
                continue;
            } // Already applied on this line.
            if p.needle.is_empty() {
                continue;
            }
            if let Some(pos) = memmem::find(value, p.needle) {
                match best_pos {
                    None => {
                        best_pos = Some(pos);
                        best_idx = idx;
                    }
                    Some(cur) if pos < cur => {
                        best_pos = Some(pos);
                        best_idx = idx;
                    }
                    _ => {}
                }
            }
        }
        let Some(pos) = best_pos else {
            return Ok(None);
        };

        let (left, after_left) = value.split_at(pos);
        let needle_len = patches[best_idx].needle.len();
        let (_, rest) = after_left.split_at(needle_len);

        self.obuf_write(left)?;
        self.obuf_write(patches[best_idx].repl)?;

        let consumed = value
            .len()
            .checked_sub(rest.len())
            .ok_or(Errno::EOVERFLOW)?;
        if best_idx >= (u8::BITS as usize) {
            return Err(Errno::EOVERFLOW);
        }
        #[expect(clippy::cast_possible_truncation)]
        let bit: PatchMask = 1u8 << (best_idx as u32);
        let new_mask: PatchMask = applied | bit;

        Ok(Some((consumed, new_mask)))
    }

    // Apply patches to value; each patch at most once; prefix already written.
    fn emit_patch_group_value(
        &mut self,
        mut value: &[u8],
        patches: &[Patch],
    ) -> Result<bool, Errno> {
        let mut applied: PatchMask = 0;
        let mut any = false;

        loop {
            match self.apply_patch_step(value, patches, applied)? {
                None => {
                    self.obuf_write(value)?;
                    return Ok(any || !value.is_empty());
                }
                Some((consumed, new_mask)) => {
                    any = true;
                    if consumed > value.len() {
                        return Err(Errno::EOVERFLOW);
                    }
                    let (_, rest) = value.split_at(consumed);
                    value = rest;
                    applied = new_mask;
                }
            }
        }
    }

    fn emit_line<Fd: AsFd>(&mut self, _out: Fd, line: &[u8]) -> Result<(), Errno> {
        if self.try_emit_zero_field(line)? {
            return Ok(());
        }
        if self.try_emit_patch_group(line)? {
            return Ok(());
        }
        self.obuf_write(line)
    }

    // read entire file into inbuf (heap), then process as lines, single write at end
    fn run<S: AsFd, D: AsFd>(&mut self, src: S, dst: D) -> Result<(), Errno> {
        // Grow and read until EOF.
        loop {
            let cap = self.inbuf.capacity();
            let len = self.inbuf.len();
            let free = cap.checked_sub(len).ok_or(Errno::EOVERFLOW)?;
            if free == 0 {
                // add a small chunk to reduce realloc churn, avoid large jumps
                self.inbuf
                    .try_reserve(Self::GROW_STEP)
                    .map_err(|_| Errno::ENOMEM)?;
                continue;
            }

            let cur_len = len;
            let new_len = cur_len.checked_add(free).ok_or(Errno::EOVERFLOW)?;
            self.inbuf.resize(new_len, 0);

            // read into tail
            let tail = &mut self.inbuf[cur_len..new_len];
            let n = retry_on_eintr(|| read(&src, tail))?;
            if n == 0 {
                // EOF
                self.inbuf.truncate(cur_len);
                break;
            }
            let keep_len = cur_len.checked_add(n).ok_or(Errno::EOVERFLOW)?;
            self.inbuf.truncate(keep_len);
        }

        // pull input buffer out to avoid aliasing self while writing
        let inbuf = take(&mut self.inbuf);
        let mut start: usize = 0;

        // process lines by index (checked)
        loop {
            let slice = if start <= inbuf.len() {
                &inbuf[start..]
            } else {
                return Err(Errno::EOVERFLOW);
            };
            if let Some(nl_rel) = memchr(b'\n', slice) {
                let end_incl = start.checked_add(nl_rel).ok_or(Errno::EOVERFLOW)?;
                let line_end = end_incl.checked_add(1).ok_or(Errno::EOVERFLOW)?;
                if line_end > inbuf.len() {
                    return Err(Errno::EOVERFLOW);
                }
                let line = &inbuf[start..line_end];
                self.emit_line(&dst, line)?;
                start = line_end;
                continue;
            }
            // trailing partial line
            if start < inbuf.len() {
                let line = &inbuf[start..];
                self.emit_line(&dst, line)?;
            }
            break;
        }

        // single write(2) syscall
        self.flush_all(dst)
    }
}

//
// Public API
//

/// Copies the proc_pid_status(5) document read from `src` to `dst`,
/// masking sensitive fields on the way.
///
/// # Errors
///
/// Propagates any `Errno` from buffer allocation, reading `src`, or
/// writing `dst`.
pub(crate) fn mask_proc_pid_status<S: AsFd, D: AsFd>(src: S, dst: D) -> Result<(), Errno> {
    let mut masker = ProcPidStatusMasker::new()?;
    masker.run(src, dst)
}

#[cfg(test)]
mod tests {
    use nix::{
        fcntl::OFlag,
        unistd::{pipe2, write},
    };

    use super::*;

    /// Feeds `input` through `mask_proc_pid_status` using a pair of pipes and
    /// returns everything written to the output side.
    ///
    /// The input is written in full before masking, and the output is read
    /// only after masking returns. Both `input` and the masked output must
    /// therefore fit in the pipe buffer, or this helper blocks.
    fn run_mask(input: &[u8]) -> Result<Vec<u8>, Errno> {
        let (in_rd, in_wr) = pipe2(OFlag::O_CLOEXEC)?;
        let (out_rd, out_wr) = pipe2(OFlag::O_CLOEXEC)?;

        // Write input fully, retrying on EINTR like the read loop below.
        let mut off = 0usize;
        while off < input.len() {
            let n = retry_on_eintr(|| write(&in_wr, &input[off..]))?;
            if n == 0 {
                // Treat a zero-length write as an error rather than
                // masking a silently truncated input.
                return Err(Errno::EIO);
            }
            off = off.checked_add(n).ok_or(Errno::EOVERFLOW)?;
        }
        drop(in_wr);

        mask_proc_pid_status(&in_rd, &out_wr)?;
        drop(out_wr);

        // Read all output.
        let mut out = Vec::new();
        let mut buf = [0u8; 1024];
        loop {
            match retry_on_eintr(|| read(&out_rd, &mut buf)) {
                Ok(0) => break,
                Ok(n) => out.extend_from_slice(&buf[..n]),
                Err(e) => return Err(e),
            }
        }
        Ok(out)
    }

    //
    // Zeroing Paths
    //
    #[test]
    fn test_pps_mask_zero_simple_fields() {
        let input = b"TracerPid:\t123\nNoNewPrivs:\t1\nSeccomp:\t2\nSeccomp_filters:\t7\n";
        let out = run_mask(input).unwrap();
        let expected = b"TracerPid:\t0\nNoNewPrivs:\t0\nSeccomp:\t0\nSeccomp_filters:\t0\n";
        assert_eq!(&out, expected);
    }

    #[test]
    fn test_pps_mask_preserve_whitespace() {
        let input = b"TracerPid:\t   456\nSeccomp:\t\t  2\n";
        let out = run_mask(input).unwrap();
        let expected = b"TracerPid:\t   0\nSeccomp:\t\t  0\n";
        assert_eq!(&out, expected);
    }

    #[test]
    fn test_pps_mask_zero_already_zero_passthrough() {
        let input = b"TracerPid:\t0\nNoNewPrivs:\t0\n";
        let out = run_mask(input).unwrap();
        assert_eq!(&out, input);
    }

    #[test]
    fn test_pps_mask_suffix_after_digits_dropped() {
        let input = b"TracerPid:\t123 extra_garbage\n";
        let out = run_mask(input).unwrap();
        assert_eq!(&out, b"TracerPid:\t0\n");
    }

    #[test]
    fn test_pps_mask_repeated_header_passthrough() {
        // Each header is masked only on its first occurrence.
        let input = b"TracerPid:\t1\nTracerPid:\t2\n";
        let out = run_mask(input).unwrap();
        assert_eq!(&out, b"TracerPid:\t0\nTracerPid:\t2\n");
    }

    //
    // Speculation Patching
    //
    #[test]
    fn test_pps_mask_ssb_thread() {
        // "thread" removed, "mitigated" -> "vulnerable"
        let input = b"Speculation_Store_Bypass:   \t\tthread mitigated\n";
        let out = run_mask(input).unwrap();
        assert_eq!(&out, b"Speculation_Store_Bypass:   \t\tvulnerable\n");
    }

    #[test]
    fn test_pps_mask_ssb_force() {
        // "force" removed, "mitigated" -> "vulnerable"
        let input = b"Speculation_Store_Bypass:\t   force mitigated\n";
        let out = run_mask(input).unwrap();
        assert_eq!(&out, b"Speculation_Store_Bypass:\t   vulnerable\n");
    }

    #[test]
    fn test_pps_mask_ssb_thread_force() {
        // "thread" removed, "force" removed, "mitigated" -> "vulnerable"
        let input = b"Speculation_Store_Bypass:\tthread force mitigated\n";
        let out = run_mask(input).unwrap();
        assert_eq!(&out, b"Speculation_Store_Bypass:\tvulnerable\n");
    }

    #[test]
    fn test_pps_mask_ssb_thread_vulnerable() {
        // "thread" removed, already-vulnerable wording kept
        let input = b"Speculation_Store_Bypass:\tthread vulnerable\n";
        let out = run_mask(input).unwrap();
        assert_eq!(&out, b"Speculation_Store_Bypass:\tvulnerable\n");
    }

    #[test]
    fn test_pps_mask_ssb_not_vulnerable_passthrough() {
        let input = b"Speculation_Store_Bypass:\tnot vulnerable\n";
        let out = run_mask(input).unwrap();
        assert_eq!(&out, input);
    }

    #[test]
    fn test_pps_mask_unknown_unsupported() {
        let input = b"Speculation_Store_Bypass:\tunknown\nSpeculationIndirectBranch:\tunsupported\n";
        let out = run_mask(input).unwrap();
        assert_eq!(
            &out,
            b"Speculation_Store_Bypass:\tvulnerable\nSpeculationIndirectBranch:\talways enabled\n"
        );
    }

    #[test]
    fn test_pps_mask_sib_conditional_force_disabled() {
        let input = b"SpeculationIndirectBranch:\t conditional force disabled\n";
        let out = run_mask(input).unwrap();
        assert_eq!(&out, b"SpeculationIndirectBranch:\t always enabled\n");
    }

    #[test]
    fn test_pps_mask_sib_always_force_disabled() {
        let input = b"SpeculationIndirectBranch: \talways force disabled\n";
        let out = run_mask(input).unwrap();
        // generic collapse after removing "force"
        assert_eq!(&out, b"SpeculationIndirectBranch: \talways enabled\n");
    }

    #[test]
    fn test_pps_mask_sib_always_disabled() {
        let input = b"SpeculationIndirectBranch:\talways disabled\n";
        let out = run_mask(input).unwrap();
        assert_eq!(&out, b"SpeculationIndirectBranch:\talways enabled\n");
    }

    #[test]
    fn test_pps_mask_sib_not_affected_passthrough() {
        let input = b"SpeculationIndirectBranch:\tnot affected\n";
        let out = run_mask(input).unwrap();
        assert_eq!(&out, input);
    }

    #[test]
    fn test_pps_mask_sib_conditional_enabled() {
        let input = b"SpeculationIndirectBranch:\t  \tconditional enabled\n";
        let out = run_mask(input).unwrap();
        assert_eq!(&out, b"SpeculationIndirectBranch:\t  \talways enabled\n");
    }

    //
    // Ordering Robustness (headers do not repeat, but order is not guaranteed)
    //
    #[test]
    fn test_pps_mask_reordered_lines_basic() {
        let input = concat!(
            "NoNewPrivs:\t1\n",
            "TracerPid:\t42\n",
            "Seccomp:\t2\n",
            "Seccomp_filters:\t3\n",
        )
        .as_bytes();
        let expected = concat!(
            "NoNewPrivs:\t0\n",
            "TracerPid:\t0\n",
            "Seccomp:\t0\n",
            "Seccomp_filters:\t0\n",
        )
        .as_bytes();
        let out = run_mask(input).unwrap();
        assert_eq!(&out, expected);
    }

    #[test]
    fn test_pps_mask_reordered_lines_with_spec() {
        let input = concat!(
            "SpeculationIndirectBranch: \t  conditional enabled\n",
            "NoNewPrivs:\t1\n",
            "Speculation_Store_Bypass:   \t\t thread force mitigated\n",
            "Seccomp:\t  2\n",
            "TracerPid: \t42\n",
            "Seccomp_filters:  \t\t 3\n",
        )
        .as_bytes();
        let expected = concat!(
            "SpeculationIndirectBranch: \t  always enabled\n",
            "NoNewPrivs:\t0\n",
            "Speculation_Store_Bypass:   \t\t vulnerable\n",
            "Seccomp:\t  0\n",
            "TracerPid: \t0\n",
            "Seccomp_filters:  \t\t 0\n",
        )
        .as_bytes();
        let out = run_mask(input).unwrap();
        assert_eq!(&out, expected);
    }

    //
    // Whitespace Robustness
    //
    #[test]
    fn test_pps_mask_weird_whitespace_tabs_spaces() {
        let input = b"TracerPid:\t \t \t 999\nNoNewPrivs:\t\t\t3\n";
        let out = run_mask(input).unwrap();
        assert_eq!(&out, b"TracerPid:\t \t \t 0\nNoNewPrivs:\t\t\t0\n");
    }

    #[test]
    fn test_pps_mask_no_digits_after_prefix() {
        // if no digits follow the field, passthrough unchanged
        let input = b"Seccomp:\t\t\n";
        let out = run_mask(input).unwrap();
        assert_eq!(&out, input);
    }

    //
    // Passthrough and limits
    //
    #[test]
    fn test_pps_mask_other_lines_unchanged() {
        let input = b"Name:\tcat\nState:\tS (sleeping)\nThreads:\t4\n";
        let out = run_mask(input).unwrap();
        assert_eq!(&out, input);
    }

    #[test]
    fn test_pps_mask_prefix_must_be_line_start() {
        let input = b"Name:\tSeccomp:\t2 (not a header)\n";
        let out = run_mask(input).unwrap();
        assert_eq!(&out, input);
    }

    #[test]
    fn test_pps_mask_long_line_zeroing() {
        let mut line = b"TracerPid:\t".to_vec();
        line.extend(std::iter::repeat(b'9').take(9000));
        line.push(b'\n');
        let out = run_mask(&line).unwrap();
        assert_eq!(out, b"TracerPid:\t0\n".to_vec());
    }

    #[test]
    fn test_pps_mask_long_nonmatching_passthrough() {
        let mut line = vec![b'A'; 10000];
        line.push(b'\n');
        let out = run_mask(&line).unwrap();
        assert_eq!(&out, &line);
    }

    #[test]
    fn test_pps_mask_combined_document_full() {
        let input = concat!(
            "Name:\tmyproc\n",
            "TracerPid:\t42\n",
            "Speculation_Store_Bypass:\tthread force mitigated\n",
            "NoNewPrivs:\t1\n",
            "SpeculationIndirectBranch:\t  conditional force disabled\n",
            "Seccomp:\t2\n",
            "Threads:\t5\n",
            "Seccomp_filters:\t3\n",
        )
        .as_bytes();

        let expected = concat!(
            "Name:\tmyproc\n",
            "TracerPid:\t0\n",
            "Speculation_Store_Bypass:\tvulnerable\n",
            "NoNewPrivs:\t0\n",
            "SpeculationIndirectBranch:\t  always enabled\n",
            "Seccomp:\t0\n",
            "Threads:\t5\n",
            "Seccomp_filters:\t0\n",
        )
        .as_bytes();

        let out = run_mask(input).unwrap();
        assert_eq!(&out, expected);
    }

    #[test]
    fn test_pps_mask_no_final_newline_passthrough_nonmatching() {
        let input = b"Name:\tno_nl_at_end";
        let out = run_mask(input).unwrap();
        assert_eq!(&out, input);
    }

    #[test]
    fn test_pps_mask_full() {
        // Build a long nonmatching line to force growth of the input buffer.
        let mut long = vec![b'X'; 4096];
        long.push(b'\n');

        let input = [
            b"Name:\tmyproc\n".as_ref(),
            b"State:\tS (sleeping)\n".as_ref(),
            b"TracerPid:\t   456\n".as_ref(), // zero with mixed ws
            b"NoNewPrivs:\t1\n".as_ref(),     // zero
            b"Speculation_Store_Bypass: \tthread force mitigated\n".as_ref(), // SSB patches
            b"SpeculationIndirectBranch:\t conditional force disabled\n".as_ref(), // SIB patches
            b"Seccomp:\t2\n".as_ref(),        // zero
            b"Threads:\t5\n".as_ref(),        // passthrough
            b"Seccomp_filters:\t3\n".as_ref(), // zero
            b"Note:\tSeccomp:\t2 (not a header)\n".as_ref(), // embedded token, passthrough
            &long,                            // long nonmatching line
            b"Name:\tno_nl_at_end".as_ref(),  // trailing partial line
        ]
        .concat();

        let expected = [
            b"Name:\tmyproc\n".as_ref(),
            b"State:\tS (sleeping)\n".as_ref(),
            b"TracerPid:\t   0\n".as_ref(),
            b"NoNewPrivs:\t0\n".as_ref(),
            b"Speculation_Store_Bypass: \tvulnerable\n".as_ref(), // generic collapse applied once
            b"SpeculationIndirectBranch:\t always enabled\n".as_ref(), // collapse after removing "force"
            b"Seccomp:\t0\n".as_ref(),
            b"Threads:\t5\n".as_ref(),
            b"Seccomp_filters:\t0\n".as_ref(),
            b"Note:\tSeccomp:\t2 (not a header)\n".as_ref(),
            &long,
            b"Name:\tno_nl_at_end".as_ref(),
        ]
        .concat();

        let out = run_mask(&input).unwrap();
        assert_eq!(&out, &expected);
    }
}
