use std::io::{BufWriter, Write};
use std::num::TryFromIntError;
use std::os::unix::io::OwnedFd;

use itertools::Itertools;
use log::debug;
use memmap2::Mmap;
use wayland_client::protocol::wl_shm::Format;

use crate::BufferInfo;

#[derive(thiserror::Error, Debug)]
pub enum Error {
    #[error("error encoding data into png")]
    Encoding(#[from] png::EncodingError),

    // This is only manually constructed when calling `Mmap::map`.
    #[error("error mapping file into memory")]
    Mmap(std::io::Error),

    // This is used as a catch-all for most output errors.
    #[error("failed to write to output file")]
    Io(#[from] std::io::Error),

    #[error("buffer format not implemented")]
    FormatNotImplemented,

    #[error("image size larger does not fit usize")]
    ImageToolarge(#[from] TryFromIntError),

    // I think this should never happen...?
    #[error("internal error flushing encoded PNG data")]
    FlushError(#[from] std::io::IntoInnerError<BufWriter<Vec<u8>>>),
}

/// Reorders a parenthesised list of dereferenced identifiers for little-endian targets.
///
/// Callers list the values in their "natural" order; on little-endian architectures this
/// expands to the same values in reverse order. Lists of three and four items are supported.
#[cfg(target_endian = "little")]
macro_rules! normalise_endianness {
    ((*$first:ident, *$second:ident, *$third:ident, *$fourth:ident)) => {
        (*$fourth, *$third, *$second, *$first)
    };
    ((*$first:ident, *$second:ident, *$third:ident)) => {
        (*$third, *$second, *$first)
    };
}

/// Big-endian counterpart of the little-endian macro of the same name.
///
/// Expands to the given list of dereferenced identifiers in unchanged order, so call sites can
/// be written once for both kinds of target. Lists of three and four items are supported.
#[cfg(target_endian = "big")]
macro_rules! normalise_endianness {
    ((*$first:ident, *$second:ident, *$third:ident, *$fourth:ident)) => {
        (*$first, *$second, *$third, *$fourth)
    };
    ((*$first:ident, *$second:ident, *$third:ident)) => {
        (*$first, *$second, *$third)
    };
}

// Returns a tuple with: input pixel bpp, output bit depth, and color type.
#[inline]
fn pixel_info_for_format(format: Format) -> Result<(usize, png::BitDepth, png::ColorType), Error> {
    match format {
        Format::Xrgb8888 | Format::Xbgr8888 => Ok((4, png::BitDepth::Eight, png::ColorType::Rgb)),
        Format::Argb8888 | Format::Abgr8888 => Ok((4, png::BitDepth::Eight, png::ColorType::Rgba)),
        Format::Xrgb2101010 | Format::Xbgr2101010 => {
            Ok((4, png::BitDepth::Sixteen, png::ColorType::Rgb))
        }
        _ => Err(Error::FormatNotImplemented),
    }
}

/// Encodes the raw image in `raw_file` as a PNG.
///
/// `raw_file` will be `mmap`ed to memory for reading. If the underlying file is backed by memory,
/// this mapping is a no-op.
#[allow(clippy::cast_possible_truncation)]
pub fn write_buffer_to_png_file(
    raw_file: &OwnedFd,
    buffer_info: &BufferInfo,
    // file: &File,
) -> Result<Vec<u8>, Error> {
    let width = usize::try_from(buffer_info.width)?;
    let stride = usize::try_from(buffer_info.stride)?;
    let height = usize::try_from(buffer_info.height)?;

    // SAFETY: we don't map any of this memory into concrete types and only read raw bytes.
    let mmap = unsafe { Mmap::map(raw_file) }.map_err(Error::Mmap)?;

    let (pixel_bpp, png_depth, png_color_type) = pixel_info_for_format(buffer_info.format)?;

    // Iterator over image bytes. Drops all bytes past end of stride.
    let bytes_iter = mmap
        .chunks(stride)
        .flat_map(|row| &row[..(width * pixel_bpp)]);

    let out_channels = if png_color_type == png::ColorType::Rgba {
        4
    } else {
        3
    };
    let out_pixel_size = if png_depth == png::BitDepth::Eight {
        out_channels
    } else {
        2 * out_channels
    };

    // A new in-memory buffer for image data converted to RGB.
    let mut image_data = vec![0u8; width * height * out_pixel_size]; // XXX: Can fail if too big

    debug!("Encoding image with source format {:?}", buffer_info.format);
    // Convert format into to RGB (by copying into the new buffer).
    if buffer_info.format == Format::Xrgb8888 {
        bytes_iter
            .tuples()
            .zip(image_data.iter_mut().tuples())
            .for_each(|((x, r, g, b), (r2, g2, b2))| {
                (_, *r2, *g2, *b2) = normalise_endianness!((*x, *r, *g, *b));
            });
    } else if buffer_info.format == Format::Xbgr8888 {
        bytes_iter
            .tuples()
            .zip(image_data.iter_mut().tuples())
            .for_each(|((x, b, g, r), (r2, g2, b2))| {
                (*r2, *g2, *b2, _) = normalise_endianness!((*r, *g, *b, *x));
            });
    } else if buffer_info.format == Format::Argb8888 {
        bytes_iter
            .tuples()
            .zip(image_data.iter_mut().tuples())
            .for_each(|((a, r, g, b), (r2, g2, b2, a2))| {
                (*r2, *g2, *b2, *a2) = normalise_endianness!((*r, *g, *b, *a));
            });
    } else if buffer_info.format == Format::Abgr8888 {
        bytes_iter
            .tuples()
            .zip(image_data.iter_mut().tuples())
            .for_each(|((a, b, g, r), (r2, g2, b2, a2))| {
                (*a2, *r2, *g2, *b2) = normalise_endianness!((*a, *r, *g, *b));
            });
    } else if buffer_info.format == Format::Xrgb2101010 {
        // at BitDepth::Sixteen, write_image_data is given big endian u16 data
        bytes_iter
            .tuples()
            .zip(image_data.iter_mut().tuples())
            .for_each(|((x0, x1, x2, x3), (r2, r1, g2, g1, b2, b1))| {
                let v = u32::from(*x3) << 24
                    | u32::from(*x2) << 16
                    | u32::from(*x1) << 8
                    | u32::from(*x0);
                let r10 = (v & 0x3ff0_0000) >> 20;
                let g10 = (v & 0xffc00) >> 10;
                let b10 = v & 0x3ff;
                let r = r10 << 6 | r10 >> 4;
                let g = g10 << 6 | g10 >> 4;
                let b = b10 << 6 | b10 >> 4;
                (*r2, *r1, *g2, *g1, *b2, *b1) = (
                    (r >> 8) as u8,
                    (r & 0xff) as u8,
                    (g >> 8) as u8,
                    (g & 0xff) as u8,
                    (b >> 8) as u8,
                    (b & 0xff) as u8,
                );
            });
    } else if buffer_info.format == Format::Xbgr2101010 {
        bytes_iter
            .tuples()
            .zip(image_data.iter_mut().tuples())
            .for_each(|((x0, x1, x2, x3), (r2, r1, g2, g1, b2, b1))| {
                let v = u32::from(*x3) << 24
                    | u32::from(*x2) << 16
                    | u32::from(*x1) << 8
                    | u32::from(*x0);
                let b10 = (v & 0x3ff0_0000) >> 20;
                let g10 = (v & 0xffc00) >> 10;
                let r10 = v & 0x3ff;
                let r = r10 << 6 | r10 >> 4;
                let g = g10 << 6 | g10 >> 4;
                let b = b10 << 6 | b10 >> 4;
                (*r2, *r1, *g2, *g1, *b2, *b1) = (
                    (r >> 8) as u8,
                    (r & 0xff) as u8,
                    (g >> 8) as u8,
                    (g & 0xff) as u8,
                    (b >> 8) as u8,
                    (b & 0xff) as u8,
                );
            });
    } else {
        return Err(Error::FormatNotImplemented);
    }

    let output = Vec::new();
    // TODO: can I std::io::Cursor and reduce system calls? This needs to be benchmarked.

    let mut file_writer = BufWriter::new(output);
    let mut encoder = png::Encoder::new(&mut file_writer, buffer_info.width, buffer_info.height);
    encoder.set_depth(png_depth);
    encoder.set_color(png_color_type);
    let mut writer = encoder.write_header()?;
    writer.write_image_data(&image_data)?;
    writer.finish()?;
    file_writer.flush()?;

    Ok(file_writer.into_inner()?)
}
