//
// Syd: rock-solid application kernel
// src/asm.rs: Assembly instruction decoder
//
// Copyright (c) 2025 Ali Polatel <alip@chesswob.org>
//
// SPDX-License-Identifier: GPL-3.0

// SAFETY: This module has been liberated from unsafe code!
#![forbid(unsafe_code)]

use std::{
    fmt::Write as FmtWrite,
    fs::File,
    io::{BufWriter, Write},
    process::{Command, Stdio},
};

use iced_x86::{Decoder, DecoderOptions, FastFormatter, Formatter, IntelFormatter};
use libseccomp::ScmpArch;
use nix::{
    errno::Errno,
    unistd::{mkstemp, unlink},
};
use raki::{Decode, Isa};
use serde::{ser::SerializeSeq, Serialize, Serializer};
use yaxpeax_arch::{Arch, Decoder as ArmDecoder, Reader, U8Reader};
use yaxpeax_arm::{armv7::ARMv7, armv8::a64::ARMv8};

use crate::err::err2no;

/// One disassembled instruction as produced by the decoders in this module.
#[derive(Debug, Clone)]
pub struct Instruction {
    /// Textual form of the instruction (mnemonic plus operands), or `None`
    /// when no text is available (e.g. for zero padding).
    pub op: Option<String>,
    /// The raw instruction bytes, hex-encoded with two digits per byte.
    pub hex: String,
}

impl Serialize for Instruction {
    /// Serialize as a three-element sequence `[op, size, hex]`.
    ///
    /// `size` is the instruction length in bytes, derived from `hex`
    /// (two hex digits per byte).
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // The length hint must match the number of elements actually
        // serialized below: length-prefixed formats emit it up front.
        let mut seq = serializer.serialize_seq(Some(3))?;

        seq.serialize_element(&self.op)?;
        seq.serialize_element(&(self.hex.len() / 2))?;
        seq.serialize_element(&self.hex)?;

        seq.end()
    }
}

/// Runtime-selectable iced-x86 formatter, wrapping either the Intel or the
/// Fast formatter.
///
/// An enum is used instead of a trait object because `FastFormatter` does not
/// implement the `Formatter` trait.
// The Intel variant is much larger, but only one value is built per
// `disasm_x86` call, so boxing it would buy nothing.
#[allow(clippy::large_enum_variant)]
enum X86Formatter {
    /// iced-x86's configurable Intel-syntax formatter.
    Intel(IntelFormatter),
    /// iced-x86's specialized formatter: faster, with fewer options.
    Fast(FastFormatter),
}

impl X86Formatter {
    /// Append the textual form of `instr` to `output` using whichever
    /// formatter this value wraps.
    fn format(&mut self, instr: &iced_x86::Instruction, output: &mut String) {
        match *self {
            X86Formatter::Intel(ref mut intel) => intel.format(instr, output),
            X86Formatter::Fast(ref mut fast) => fast.format(instr, output),
        }
    }
}

/// Disassemble raw machine code into a vector of instructions.
///
/// x86 (`X8664`, `X86`, `X32`), ARM (`Arm`, `Aarch64`) and `Riscv64` are
/// decoded natively. Every other architecture falls back to running
/// `objdump` on a temporary file and parsing its textual output.
///
/// - `ip`: address of the first byte; only used by the x86 decoder.
/// - `fast_fmt`: use iced-x86's `FastFormatter` (x86 only).
/// - `verbose`: let `objdump` inherit our standard error (objdump fallback only).
///
/// # Errors
///
/// Native decoders return their own errors. The objdump fallback returns an
/// errno when:
/// - the temporary file cannot be created, written or removed;
/// - `objdump` cannot be spawned;
/// - `objdump`'s output is not valid UTF-8.
///
/// objdump's exit status is not checked: a failing objdump yields whatever
/// it printed on standard output.
pub fn disasm(
    machine_code: &[u8],
    arch: ScmpArch,
    ip: u64,
    fast_fmt: bool,
    verbose: bool,
) -> Result<Vec<Instruction>, Errno> {
    let objdump_arch = match arch {
        // Use native X86 decoder.
        ScmpArch::X8664 | ScmpArch::X86 | ScmpArch::X32 => {
            return disasm_x86(machine_code, arch, ip, fast_fmt, DecoderOptions::NONE)
        }
        // Use native ARM decoder.
        ScmpArch::Aarch64 | ScmpArch::Arm => return disasm_arm(machine_code, arch),
        // Use native Riscv64 decoder.
        ScmpArch::Riscv64 => return disasm_riscv64(machine_code),
        // Everything else falls back to objdump.
        // TODO: Add llvm-objdump support!
        _ => scmp_arch2objdump(&arch),
    };

    // Create a private temporary file for objdump to read.
    let (fd, path) = mkstemp("/tmp/syd_objdumpXXXXXX")?;
    let stdout = objdump_file(File::from(fd), &path, machine_code, objdump_arch, verbose);

    // Remove the temporary file on every path, not only on success, so a
    // failed write or a missing objdump binary does not leak files in /tmp.
    let unlinked = unlink(&path);
    let stdout = stdout?;
    unlinked?;

    let output = std::str::from_utf8(&stdout).or(Err(Errno::EINVAL))?;
    Ok(parse_objdump_output(output))
}

/// Write `machine_code` to `file` (the open file at `path`), run
/// `objdump -D -b binary -m <arch>` on it and return objdump's stdout.
fn objdump_file(
    file: File,
    path: &std::path::Path,
    machine_code: &[u8],
    arch: &str,
    verbose: bool,
) -> Result<Vec<u8>, Errno> {
    // Flush explicitly: dropping a BufWriter silently discards flush errors,
    // which would hand objdump a truncated file.
    let mut file = BufWriter::new(file);
    file.write_all(machine_code).map_err(|err| err2no(&err))?;
    file.flush().map_err(|err| err2no(&err))?;
    drop(file);

    let mut command = Command::new("objdump");
    if verbose {
        command.stderr(Stdio::inherit());
    }
    let output = command
        .env("LC_ALL", "C")
        .env("LANG", "C")
        .arg("-D")
        .arg("-b")
        .arg("binary")
        .arg("-m")
        .arg(arch)
        .arg(path)
        .output()
        .map_err(|err| err2no(&err))?;

    Ok(output.stdout)
}

/// Parse the textual output of `objdump -D` into instructions.
///
/// Only lines of the form `ADDR:...` with a non-empty hexadecimal `ADDR` are
/// considered; headers, section names and labels are skipped.
///
/// objdump separates the raw opcode bytes from the instruction text with a
/// tab, and that tab is used as the split point. Mnemonics such as `add`,
/// `adc` or `dadd` consist solely of hex digits, so classifying
/// whitespace-separated tokens alone would fold them into the opcode bytes.
///
/// Leading hex tokens are still collected token by token, so output without
/// the tab degrades to that heuristic instead of failing.
fn parse_objdump_output(output: &str) -> Vec<Instruction> {
    fn is_hex(s: &str) -> bool {
        s.chars().all(|c| c.is_ascii_hexdigit())
    }

    let mut instructions = Vec::new();
    for line in output.lines() {
        let Some((address, rest)) = line.trim().split_once(':') else {
            continue;
        };
        if address.is_empty() || !is_hex(address) {
            continue;
        }

        // Raw bytes run up to the first tab; the instruction text follows it
        // and may itself contain further tabs.
        let rest = rest.trim_start();
        let (bytes_field, op_field) = rest.split_once('\t').unwrap_or((rest, ""));

        let tokens: Vec<&str> = bytes_field.split_whitespace().collect();
        let hex_end = tokens
            .iter()
            .position(|&t| !is_hex(t))
            .unwrap_or(tokens.len());
        let hex = tokens[..hex_end].concat();

        // Mnemonic + operands, with runs of whitespace collapsed to one space.
        let op = tokens[hex_end..]
            .iter()
            .copied()
            .chain(op_field.split_whitespace())
            .collect::<Vec<_>>()
            .join(" ");

        if hex.is_empty() && op.is_empty() {
            continue;
        }

        instructions.push(Instruction {
            hex,
            op: if op.is_empty() { None } else { Some(op) },
        });
    }

    instructions
}

/// Disassemble raw machine code into a vector of instructions.
///
/// `arch` _must_ be one of `ScmpArch::X8664`, `ScmpArch::X86`, or `ScmpArch::X32`,
/// or else this function will return `Err(Errno::ENOSYS)`.
pub fn disasm_x86(
    machine_code: &[u8],
    arch: ScmpArch,
    ip: u64,
    fast_fmt: bool,
    opts: u32,
) -> Result<Vec<Instruction>, Errno> {
    // Determine bitness.
    let bitness = match arch {
        ScmpArch::X8664 => 64,
        ScmpArch::X86 => 32,
        ScmpArch::X32 => 32,
        _ => return Err(Errno::ENOSYS),
    };

    // Create an iced-x86 decoder with the given IP.
    let mut decoder = Decoder::with_ip(bitness, machine_code, ip, opts);

    // Select our runtime formatter, storing it in the enum.
    let mut formatter = if fast_fmt {
        X86Formatter::Fast(FastFormatter::new())
    } else {
        X86Formatter::Intel(IntelFormatter::new())
    };

    // We'll store the final instructions in this vector.
    let mut instructions = Vec::new();

    // Reusable iced_x86 Instruction to avoid extra allocations.
    let mut iced_instr = iced_x86::Instruction::default();

    // Decode until no bytes remain or we hit invalid data.
    while decoder.can_decode() {
        // Decode into `iced_instr`.
        // If it's invalid, we push a “null” instruction.
        decoder.decode_out(&mut iced_instr);

        if iced_instr.is_invalid() {
            // We attempt to extract the failing byte, if any
            let fault_pos = decoder.position().saturating_sub(1);
            let null_hex = if fault_pos < machine_code.len() {
                format!("{:02x}", machine_code[fault_pos])
            } else {
                String::new()
            };

            instructions.push(Instruction {
                hex: null_hex,
                op: Some("null".to_string()),
            });
            continue;
        }

        // Format the instruction.
        let mut text = String::new();
        formatter.format(&iced_instr, &mut text);

        // Instruction size in bytes
        let instr_len = iced_instr.len();
        let end_pos = decoder.position();
        let start_pos = end_pos.saturating_sub(instr_len);

        // Extract the raw bytes,
        // and convert to a hex string (e.g. "0f1f8400000000")
        let raw_bytes = &machine_code[start_pos..end_pos];
        let hex_str = raw_bytes
            .iter()
            .map(|b| format!("{b:02x}"))
            .collect::<Vec<_>>()
            .join("");

        // Push our final instruction struct.
        instructions.push(Instruction {
            hex: hex_str,
            // e.g. "syscall", "nopl 0x0(%rax,%rax,1)", etc.
            op: Some(text),
        });
    }

    Ok(instructions)
}

/// Disassemble raw ARM machine code into a vector of instructions.
///
/// - `arch` must be either `ScmpArch::Arm` (ARMv7) or `ScmpArch::Aarch64` (ARMv8),
///   or this returns `Err(Errno::ENOSYS)`.
///
/// If a decode error occurs, we push a pseudo "null" instruction for the single
/// offending byte and skip it.
fn disasm_arm(machine_code: &[u8], arch: ScmpArch) -> Result<Vec<Instruction>, Errno> {
    match arch {
        ScmpArch::Arm => disasm_armv7(machine_code),
        ScmpArch::Aarch64 => disasm_armv8(machine_code),
        _ => Err(Errno::ENOSYS),
    }
}

/// Helper to decode ARMv7 instructions from `machine_code` using `yaxpeax_arm::armv7::ARMv7`.
/// Returns a vector of `Instruction` with `.hex` and `.op` fields.
///
/// Each successfully decoded instruction yields one entry whose `hex` holds
/// exactly the bytes the decoder consumed, in input order.
///
/// When decoding fails at some offset (an invalid encoding, or fewer bytes
/// left than an instruction needs), the single byte at that offset is emitted
/// as a `"null"` entry and decoding restarts at the next byte.
#[allow(clippy::arithmetic_side_effects)]
fn disasm_armv7(machine_code: &[u8]) -> Result<Vec<Instruction>, Errno> {
    let decoder = <ARMv7 as Arch>::Decoder::default();
    let mut instructions = Vec::new();
    let mut offset = 0usize;

    while offset < machine_code.len() {
        let window = &machine_code[offset..];

        // Use a fresh reader for every attempt. A failed decode may already
        // have consumed bytes from the reader, so reusing a single reader
        // would let it drift away from `offset` and corrupt later entries.
        let mut reader = U8Reader::new(window);
        let decoded = match decoder.decode(&mut reader) {
            Ok(inst) => {
                let len = <U8Reader<'_> as Reader<u32, u8>>::total_offset(&mut reader) as usize;
                Some((len, inst))
            }
            Err(_decode_err) => None,
        };

        match decoded {
            // A zero-length decode would never advance; treat it as a failure.
            Some((len, inst)) if len > 0 && len <= window.len() => {
                let mut hex_str = String::new();
                for b in &window[..len] {
                    write!(&mut hex_str, "{b:02x}").or(Err(Errno::ENOMEM))?;
                }

                instructions.push(Instruction {
                    hex: hex_str,
                    op: Some(inst.to_string()),
                });
                offset += len;
            }
            _ => {
                // Push a "null" for the offending byte and skip it.
                let b = window[0];
                instructions.push(Instruction {
                    hex: format!("{b:02x}"),
                    op: Some("null".to_string()),
                });
                offset += 1;
            }
        }
    }

    Ok(instructions)
}

/// Helper to decode ARMv8 (AArch64) instructions from `machine_code`
/// using `yaxpeax_arm::armv8::a64::ARMv8`.
///
/// Each successfully decoded instruction yields one entry whose `hex` holds
/// exactly the bytes the decoder consumed, in input order.
///
/// When decoding fails at some offset (an invalid encoding, or fewer bytes
/// left than an instruction needs), the single byte at that offset is emitted
/// as a `"null"` entry and decoding restarts at the next byte.
#[allow(clippy::arithmetic_side_effects)]
#[allow(clippy::cast_possible_truncation)]
fn disasm_armv8(machine_code: &[u8]) -> Result<Vec<Instruction>, Errno> {
    let decoder = <ARMv8 as Arch>::Decoder::default();
    let mut instructions = Vec::new();
    let mut offset = 0usize;

    while offset < machine_code.len() {
        let window = &machine_code[offset..];

        // Use a fresh reader for every attempt. A failed decode may already
        // have consumed bytes from the reader, so reusing a single reader
        // would let it drift away from `offset` and corrupt later entries.
        let mut reader = U8Reader::new(window);
        let decoded = match decoder.decode(&mut reader) {
            Ok(inst) => {
                let len = <U8Reader<'_> as Reader<u64, u8>>::total_offset(&mut reader) as usize;
                Some((len, inst))
            }
            Err(_decode_err) => None,
        };

        match decoded {
            // A zero-length decode would never advance; treat it as a failure.
            Some((len, inst)) if len > 0 && len <= window.len() => {
                let mut hex_str = String::new();
                for b in &window[..len] {
                    write!(&mut hex_str, "{b:02x}").or(Err(Errno::ENOMEM))?;
                }

                instructions.push(Instruction {
                    hex: hex_str,
                    op: Some(inst.to_string()),
                });
                offset += len;
            }
            _ => {
                // Push a "null" for the offending byte and skip it.
                let b = window[0];
                instructions.push(Instruction {
                    hex: format!("{b:02x}"),
                    op: Some("null".to_string()),
                });
                offset += 1;
            }
        }
    }

    Ok(instructions)
}

/// Disassemble raw RISC-V (RV64) machine code into a vector of instructions.
///
/// Decoding uses the `raki` crate (`raki::Decode`) in `Isa::Rv64` mode. At each
/// position the decoder tries, in order:
/// 1. the next two bytes as a 16-bit (compressed) instruction;
/// 2. the next four bytes as a 32-bit instruction.
///
/// If neither succeeds (or too few bytes remain), the single byte at that
/// position is emitted as a `"null"` instruction, and decoding resumes at the
/// following byte.
#[allow(clippy::arithmetic_side_effects)]
pub fn disasm_riscv64(machine_code: &[u8]) -> Result<Vec<Instruction>, Errno> {
    let mut decoded = Vec::new();
    let mut cursor = 0usize;

    while cursor < machine_code.len() {
        let window = &machine_code[cursor..];

        // (width in bytes, text) of the first decoding attempt that succeeds.
        let hit = window
            .get(..2)
            .and_then(|b| u16::from_le_bytes([b[0], b[1]]).decode(Isa::Rv64).ok())
            .map(|inst| (2usize, inst.to_string()))
            .or_else(|| {
                window
                    .get(..4)
                    .and_then(|b| {
                        u32::from_le_bytes([b[0], b[1], b[2], b[3]])
                            .decode(Isa::Rv64)
                            .ok()
                    })
                    .map(|inst| (4usize, inst.to_string()))
            });

        // Fall back to a one-byte "null" pseudo instruction.
        let (width, text) = hit.unwrap_or_else(|| (1, "null".to_string()));

        let mut hex = String::new();
        for byte in &window[..width] {
            write!(&mut hex, "{byte:02x}").or(Err(Errno::ENOMEM))?;
        }

        decoded.push(Instruction {
            hex,
            op: Some(text),
        });
        cursor += width;
    }

    Ok(decoded)
}

/// Map a seccomp architecture to the machine name passed to `objdump -m`.
///
/// Architectures without a known mapping yield `"unknown"`.
pub const fn scmp_arch2objdump(arch: &ScmpArch) -> &'static str {
    match *arch {
        ScmpArch::X8664 => "i386:x86-64",
        ScmpArch::X86 => "i386",
        ScmpArch::Arm => "arm",
        ScmpArch::Aarch64 => "aarch64",
        ScmpArch::Loongarch64 => "loongarch64",
        ScmpArch::M68k => "m68k",
        ScmpArch::Mips => "mips",
        ScmpArch::Mips64 => "mips64",
        ScmpArch::Riscv64 => "riscv:rv64",
        // Both byte orders share one machine name.
        ScmpArch::Ppc64 | ScmpArch::Ppc64Le => "powerpc:common64",
        ScmpArch::S390X => "s390:64",
        ScmpArch::Sheb => "sheb",
        ScmpArch::Sh => "sh",
        _ => "unknown",
    }
}
