package jp.sourceforge.acerola3d.apng;

import javax.imageio.*;
import javax.imageio.spi.*;
import javax.imageio.stream.*;
import java.util.*;
import java.awt.image.DataBuffer;
import java.io.IOException;
import java.util.zip.DataFormatException;
import java.io.DataInputStream;
import java.awt.image.BufferedImage;
import java.awt.Rectangle;
import java.awt.Point;
import javax.imageio.metadata.IIOMetadata;
import java.awt.image.DataBufferByte;
import java.awt.image.WritableRaster;
import java.awt.image.Raster;
import java.awt.color.ColorSpace;

/**
 * An {@link ImageReader} for PNG and APNG (Animated PNG) streams.
 *
 * <p>The whole input is parsed on the first call that needs image data (see
 * {@link #readAll()}): every chunk up to and including {@code IEND} is read and
 * then dispatched to a {@code process_XXXX} handler. Completed frames are kept
 * in {@link #frames}; {@link #read(int, ImageReadParam)} returns frame {@code i}.
 *
 * <p>Only {@link ImageInputStream} inputs are accepted. The stream is owned by
 * the caller and is not closed by this reader.
 */
public class APNGImageReader extends ImageReader {
    /** The 8-byte PNG file signature: 89 50 4E 47 0D 0A 1A 0A. */
    static final byte[] pngSignature = {
        (byte)0x89,(byte)0x50,(byte)0x4e,(byte)0x47,
        (byte)0x0d,(byte)0x0a,(byte)0x1a,(byte)0x0a
    };

    /** Input set through setInput; null when no input is set. */
    ImageInputStream stream = null;

    // Header values, filled in from the IHDR chunk.
    int width;
    int height;
    int bitDepth;
    int colorType;
    int compressionMethod;
    int filterMethod;
    int interlaceMethod;
    /** PLTE entries, each a 3-byte {r, g, b} array; null if the file has no PLTE. */
    byte[][] palette;
    /** Transparency data from tRNS (key meaning depends on colorType); null if absent. */
    HashMap<Integer,Integer> alpha;
    /** Frame currently receiving IDAT/fdAT data. */
    Frame currentFrame;
    /** Completed frames in file order; index i is what read(i) returns. */
    ArrayList<Frame> frames = new ArrayList<Frame>();

    /** Counter passed to Frame for fcTL/fdAT chunks. */
    int chunkSequence=0;
    /** Frame count declared in acTL; stays 0 when the file has no acTL. */
    int num_frames;
    /** Play count declared in acTL. */
    int num_plays;

    /** True once readAll() has completed successfully for the current input. */
    boolean gotAllData = false;
    APNGMetadata metadata = null;


    public APNGImageReader(ImageReaderSpi originatingProvider) {
        super(originatingProvider);
    }

    /**
     * Sets the input and discards everything parsed from a previous input.
     *
     * <p>The three-argument overload is the one that {@code ImageIO.read},
     * {@code setInput(Object)} and {@code setInput(Object, boolean)} all end up
     * calling, so overriding it covers every entry point.
     *
     * @param input an {@link ImageInputStream}, or {@code null} to clear the input
     * @throws IllegalArgumentException if {@code input} is not an ImageInputStream
     */
    @Override
    public void setInput(Object input, boolean seekForwardOnly, boolean ignoreMetadata) {
        super.setInput(input, seekForwardOnly, ignoreMetadata);
        resetState();
        if (input == null) {
            this.stream = null;
            return;
        }
        if (input instanceof ImageInputStream) {
            this.stream = (ImageInputStream)input;
        } else {
            throw new IllegalArgumentException("bad input");
        }
    }

    /** Clears all state derived from a previous input so the reader can be reused. */
    private void resetState() {
        stream = null;
        width = 0;
        height = 0;
        bitDepth = 0;
        colorType = 0;
        compressionMethod = 0;
        filterMethod = 0;
        interlaceMethod = 0;
        palette = null;
        alpha = null;
        currentFrame = null;
        frames.clear();
        chunkSequence = 0;
        num_frames = 0;
        num_plays = 0;
        gotAllData = false;
        metadata = null;
    }

    /**
     * Returns the number of frames that {@link #read(int, ImageReadParam)} can serve.
     * This is the number of frames actually decoded, so a plain PNG (no acTL)
     * reports 1 rather than 0.
     */
    @Override
    public int getNumImages(boolean allowSearch) throws IIOException {
        readAll();
        return frames.size();
    }

    /**
     * Throws if {@code imageIndex} is not a valid frame index.
     * Must only be called after {@link #readAll()} has populated {@link #frames}.
     */
    private void checkIndex(int imageIndex) {
        if (imageIndex < 0 || imageIndex >= frames.size()) {
            throw new IndexOutOfBoundsException("bad index");
        }
    }

    /** Returns the canvas width from IHDR (the same for every index). */
    @Override
    public int getWidth(int imageIndex) throws IIOException {
        readAll();
        checkIndex(imageIndex);
        return width;
    }

    /** Returns the canvas height from IHDR (the same for every index). */
    @Override
    public int getHeight(int imageIndex) throws IIOException {
        readAll();
        checkIndex(imageIndex);
        return height;
    }

    /**
     * Returns candidate destination image types, chosen by the IHDR colorType.
     *
     * <p>NOTE(review): the entries for colorType 2 (RGB) and 3 (palette) are
     * grayscale placeholders and do not describe RGB data; they should be
     * checked against the image type that Frame.getImage() actually produces.
     */
    @Override
    public Iterator<ImageTypeSpecifier> getImageTypes(int imageIndex) throws IIOException {
        readAll();
        checkIndex(imageIndex);

        ArrayList<ImageTypeSpecifier> al = new ArrayList<ImageTypeSpecifier>();
        if (colorType==0) {
            // TODO: not confident this is right
            al.add(ImageTypeSpecifier.createGrayscale(8,DataBuffer.TYPE_BYTE,false));
        } else if (colorType==2) {
            // TODO: placeholder
            al.add(ImageTypeSpecifier.createGrayscale(8,DataBuffer.TYPE_BYTE,false));
        } else if (colorType==3) {
            // TODO: placeholder
            al.add(ImageTypeSpecifier.createGrayscale(8,DataBuffer.TYPE_BYTE,false));
        } else if (colorType==4) {
            // TODO: placeholder
            al.add(ImageTypeSpecifier.createGrayscale(8,DataBuffer.TYPE_BYTE,false));
            int datatype = DataBuffer.TYPE_BYTE;
            ColorSpace rgb =
                ColorSpace.getInstance(ColorSpace.CS_sRGB);
            int[] bandOffsets = new int[3];
            bandOffsets[0] = 0;
            bandOffsets[1] = 1;
            bandOffsets[2] = 2;
            al.add(ImageTypeSpecifier.createInterleaved(rgb,bandOffsets,datatype,false,false));
        } else if (colorType==6) {
            // TODO: placeholder
            int datatype = DataBuffer.TYPE_BYTE;
            ColorSpace rgb =
                ColorSpace.getInstance(ColorSpace.CS_sRGB);
            int[] bandOffsets = new int[4];
            bandOffsets[0] = 0;
            bandOffsets[1] = 1;
            bandOffsets[2] = 2;
            bandOffsets[3] = 3;
            al.add(ImageTypeSpecifier.createInterleaved(rgb,bandOffsets,datatype,true,false));
        }
        return al.iterator();
    }

    /**
     * Parses the entire input once: checks the PNG signature, reads every chunk
     * up to and including IEND, then processes the chunks in file order.
     *
     * <p>{@link #gotAllData} is set only after success. If parsing fails, later
     * calls fail again instead of silently returning partial state.
     *
     * @throws IllegalStateException if no input has been set
     * @throws IIOException if the signature is wrong or any chunk cannot be read/processed
     */
    public void readAll() throws IIOException{
        if (gotAllData) {
            return;
        }
        if (stream == null) {
            throw new IllegalStateException("No input stream");
        }

        byte[] signature = new byte[pngSignature.length];
        try {
            stream.readFully(signature);
        } catch (IOException e) {
            throw new IIOException("Error reading signature", e);
        }
        if (!Arrays.equals(signature, pngSignature)) {
            throw new IIOException("Bad file signature!");
        }

        try {
            ArrayList<Chunk> chunks = new ArrayList<Chunk>();
            Chunk last;
            do {
                last = new Chunk(stream);
                chunks.add(last);
            } while (!last.typeStr.equals("IEND"));
            // The stream belongs to the caller. ImageIO.read closes it itself,
            // and closing it here would make that second close fail.

            for (Chunk c : chunks) {
                processChunk(c);
            }
            metadata = new APNGMetadata();
        } catch (Exception e) {
            throw new IIOException("Exception reading data",e);
        }
        gotAllData = true;
    }

    /** Routes one chunk to its process_XXXX handler based on its 4-character type. */
    private void processChunk(Chunk c) throws IOException, DataFormatException {
        String type = c.typeStr;
        if (type.equals("IHDR"))
            process_IHDR(c);
        else if (type.equals("cHRM"))
            process_cHRM(c);
        else if (type.equals("gAMA"))
            process_gAMA(c);
        else if (type.equals("iCCP"))
            process_iCCP(c);
        else if (type.equals("sBIT"))
            process_sBIT(c);
        else if (type.equals("sRGB"))
            process_sRGB(c);
        else if (type.equals("PLTE"))
            process_PLTE(c);
        else if (type.equals("bKGD"))
            process_bKGD(c);
        else if (type.equals("hIST"))
            process_hIST(c);
        else if (type.equals("tRNS"))
            process_tRNS(c);
        else if (type.equals("pHYs"))
            process_pHYs(c);
        else if (type.equals("sPLT"))
            process_sPLT(c);
        else if (type.equals("IDAT"))
            process_IDAT(c);
        else if (type.equals("IEND"))
            process_IEND(c);
        else if (type.equals("tIME"))
            process_tIME(c);
        else if (type.equals("iTXt"))
            process_iTXt(c);
        else if (type.equals("tEXt"))
            process_tEXt(c);
        else if (type.equals("zTXt"))
            process_zTXt(c);
        else if (type.equals("acTL")) // APNG
            process_acTL(c);
        else if (type.equals("fcTL")) // APNG
            process_fcTL(c);
        else if (type.equals("fdAT")) // APNG
            process_fdAT(c);
        else
            process_OTHER(c);
    }

    /** Reads the IHDR fields: two 4-byte ints (width, height), then five unsigned bytes. */
    void process_IHDR(Chunk c) throws IOException {
        DataInputStream dis = c.getDataInputStream();
        width = dis.readInt();
        height = dis.readInt();
        // Header bytes are unsigned per the PNG spec.
        bitDepth = dis.readUnsignedByte();
        colorType = dis.readUnsignedByte();
        compressionMethod = dis.readUnsignedByte();
        filterMethod = dis.readUnsignedByte();
        interlaceMethod = dis.readUnsignedByte();
    }

    void process_cHRM(Chunk c){System.out.println("cHRM: not implemented.");}
    void process_gAMA(Chunk c){System.out.println("gAMA: not implemented.");}
    void process_iCCP(Chunk c){System.out.println("iCCP: not implemented.");}
    void process_sBIT(Chunk c){System.out.println("sBIT: not implemented.");}
    void process_sRGB(Chunk c){System.out.println("sRGB: not implemented.");}

    /** Splits PLTE data into 3-byte {r, g, b} entries; trailing bytes (length % 3) are ignored. */
    void process_PLTE(Chunk c) {
        palette = new byte[c.data.length/3][];
        for (int i=0;i<c.data.length/3;i++) {
            byte[] rgb = {c.data[3*i],c.data[3*i+1],c.data[3*i+2]};
            palette[i] = rgb;
        }
    }

    void process_bKGD(Chunk c){System.out.println("bKGD: not implemented.");}
    void process_hIST(Chunk c){System.out.println("hIST: not implemented.");}

    /**
     * Builds {@link #alpha} from a tRNS chunk.
     * <ul>
     *   <li>colorType 0: key = 16-bit gray sample, value = 0.</li>
     *   <li>colorType 2: key = packed sample bytes, value = 0 (see NOTE below).</li>
     *   <li>colorType 3: key = palette index, value = alpha in 0..255.</li>
     * </ul>
     * Any other colorType prints an error and leaves the map empty.
     */
    void process_tRNS(Chunk c) {
        alpha = new HashMap<Integer,Integer>();
        if (colorType==0) {
            // TODO: untested
            for (int i=0;i<c.data.length;i=i+2) {
                byte b0 = c.data[i+0];
                byte b1 = c.data[i+1];
                int color = 0xff & b0; color = color<<8;
                color = color | 0xff & b1;
                alpha.put(color,0);
            }
        } else if (colorType==2) {
            // TODO: untested
            // NOTE(review): six bytes (48 bits) are shifted into a 32-bit int, so
            // b0 and b1 (the red sample) are shifted out and lost. The key format
            // is left unchanged here because Frame builds its lookup keys to match it.
            for (int i=0;i<c.data.length;i=i+6) {
                byte b0 = c.data[i+0];
                byte b1 = c.data[i+1];
                byte b2 = c.data[i+2];
                byte b3 = c.data[i+3];
                byte b4 = c.data[i+4];
                byte b5 = c.data[i+5];
                int color = 0xff & b0; color = color<<8;
                color = color | 0xff & b1; color = color<<8;
                color = color | 0xff & b2; color = color<<8;
                color = color | 0xff & b3; color = color<<8;
                color = color | 0xff & b4; color = color<<8;
                color = color | 0xff & b5;
                alpha.put(color,0);
            }
        } else if (colorType==3) {
            // Alpha bytes are unsigned; without the mask, 255 would be stored as -1.
            for (int i=0;i<c.data.length;i++)
                alpha.put(i, c.data[i] & 0xff);
        } else {
            System.out.println("tRNS: ERROR!");
            System.out.println("colorType="+colorType+" must not have tRNS.");
        }
    }

    void process_pHYs(Chunk c){System.out.println("pHYs: not implemented.");}
    void process_sPLT(Chunk c){System.out.println("sPLT: not implemented.");}

    /**
     * Adds IDAT data to the current frame. If no fcTL came first, this creates
     * a frame with no fcTL for the default image.
     */
    void process_IDAT(Chunk c) throws IOException {
        if (currentFrame==null) {
            currentFrame = new Frame(this,null,-1);
        }
        currentFrame.addChunk(c,chunkSequence);
    }

    /** Finishes the last open frame. */
    void process_IEND(Chunk c) throws DataFormatException {
        if (currentFrame == null) {
            throw new DataFormatException("IEND reached without any image data");
        }
        currentFrame.createImage();
        frames.add(currentFrame);
    }

    void process_tIME(Chunk c){System.out.println("tIME: not implemented.");}
    void process_iTXt(Chunk c){System.out.println("iTXt: not implemented.");}
    void process_tEXt(Chunk c){System.out.println("tEXt: not implemented.");}
    void process_zTXt(Chunk c){System.out.println("zTXt: not implemented.");}

    // APNG-specific chunks below.

    /** Reads acTL: frame count followed by play count (two 4-byte ints). */
    void process_acTL(Chunk c) throws IOException {
        DataInputStream dis = c.getDataInputStream();
        num_frames = dis.readInt();
        if (num_frames==0) System.out.println("num_frams must not 0.");
        num_plays = dis.readInt();
    }

    /** Finishes the frame in progress (if any) and starts a new one controlled by this fcTL. */
    void process_fcTL(Chunk c) throws IOException, DataFormatException {
        if (currentFrame!=null) {
            currentFrame.createImage();
            frames.add(currentFrame);
        }
        currentFrame = new Frame(this,c,chunkSequence++);
    }

    /** Adds fdAT data to the frame opened by the preceding fcTL. */
    void process_fdAT(Chunk c) throws DataFormatException {
        if (currentFrame == null) {
            throw new DataFormatException("fdAT chunk without a preceding fcTL");
        }
        currentFrame.addChunk(c,chunkSequence++);
    }


    void process_OTHER(Chunk c) {
        System.out.println(c.typeStr+" : not implemented.");
    }

    /**
     * Returns the decoded frame at {@code imageIndex}. {@code param} is ignored.
     *
     * @throws IndexOutOfBoundsException if {@code imageIndex} is out of range
     */
    @Override
    public BufferedImage read(int imageIndex, ImageReadParam param) throws IIOException {
        readAll();
        checkIndex(imageIndex);
        return frames.get(imageIndex).getImage();
    }

    /**
     * Legacy raw-row reader adapted from an ImageIO sample. It is not used by
     * {@link #read(int, ImageReadParam)}: it reads unfiltered, uncompressed rows
     * straight from {@link #stream}, which does not match real PNG data.
     */
    public BufferedImage read_BAK(int imageIndex, ImageReadParam param) throws IIOException {

        // Compute initial source region, clip against destination later
        Rectangle sourceRegion = getSourceRegion(param, width, height);

        // Set everything to default values
        int sourceXSubsampling = 1;
        int sourceYSubsampling = 1;
        int[] sourceBands = null;
        int[] destinationBands = null;
        Point destinationOffset = new Point(0, 0);

        // Get values from the ImageReadParam, if any
        if (param != null) {
            sourceXSubsampling = param.getSourceXSubsampling();
            sourceYSubsampling = param.getSourceYSubsampling();
            sourceBands = param.getSourceBands();
            destinationBands = param.getDestinationBands();
            destinationOffset = param.getDestinationOffset();
        }

        // Get the specified destination image or create a new one
        BufferedImage dst = getDestination(param,
                                           getImageTypes(imageIndex),
                                           width, height);

        // Ensure band settings from param are compatible with the image.
        // colorType bit flags: 1 = palette used, 2 = color, 4 = alpha channel.
        int inputBands=0;
        if (colorType==0) inputBands = 1;
        else if (colorType==2) inputBands = 3;
        else if (colorType==3) inputBands = 1;
        else if (colorType==4) inputBands = 2;
        else if (colorType==6) inputBands = 4;
        checkReadParamBandSettings(param, inputBands,
                                   dst.getSampleModel().getNumBands());

        int[] bandOffsets = new int[inputBands];
        for (int i = 0; i < inputBands; i++) {
            bandOffsets[i] = i;
        }
        int bytesPerRow = width*inputBands;
        DataBufferByte rowDB = new DataBufferByte(bytesPerRow);
        WritableRaster rowRas =
            Raster.createInterleavedRaster(rowDB,
                                           width, 1, bytesPerRow,
                                           inputBands, bandOffsets,
                                           new Point(0, 0));
        byte[] rowBuf = rowDB.getData();

        // Create an int[] that can hold a single pixel
        int[] pixel = rowRas.getPixel(0, 0, (int[])null);

        WritableRaster imRas = dst.getWritableTile(0, 0);
        int dstMinX = imRas.getMinX();
        int dstMaxX = dstMinX + imRas.getWidth() - 1;
        int dstMinY = imRas.getMinY();
        int dstMaxY = dstMinY + imRas.getHeight() - 1;

        // Create a child raster exposing only the desired source bands
        if (sourceBands != null) {
            rowRas = rowRas.createWritableChild(0, 0,
                                                width, 1,
                                                0, 0,
                                                sourceBands);
        }

        // Create a child raster exposing only the desired dest bands
        if (destinationBands != null) {
            imRas = imRas.createWritableChild(0, 0,
                                              imRas.getWidth(),
                                              imRas.getHeight(),
                                              0, 0,
                                              destinationBands);
        }


        for (int srcY = 0; srcY < height; srcY++) {
            // Read the row
            try {
                stream.readFully(rowBuf);
            } catch (IOException e) {
                throw new IIOException("Error reading line " + srcY, e);
            }

            // Reject rows that lie outside the source region,
            // or which aren't part of the subsampling
            if ((srcY < sourceRegion.y) ||
                (srcY >= sourceRegion.y + sourceRegion.height) ||
                (((srcY - sourceRegion.y) %
                  sourceYSubsampling) != 0)) {
                continue;
            }

            // Determine where the row will go in the destination
            int dstY = destinationOffset.y +
                (srcY - sourceRegion.y)/sourceYSubsampling;
            if (dstY < dstMinY) {
                continue; // The row is above imRas
            }
            if (dstY > dstMaxY) {
                break; // We're done with the image
            }

            // Copy each (subsampled) source pixel into imRas
            for (int srcX = sourceRegion.x;
                 srcX < sourceRegion.x + sourceRegion.width;
                 srcX++) {
                if (((srcX - sourceRegion.x) % sourceXSubsampling) != 0) {
                    continue;
                }
                int dstX = destinationOffset.x +
                    (srcX - sourceRegion.x)/sourceXSubsampling;
                if (dstX < dstMinX) {
                    continue;  // The pixel is to the left of imRas
                }
                if (dstX > dstMaxX) {
                    break; // We're done with the row
                }

                // Copy the pixel, sub-banding is done automatically
                rowRas.getPixel(srcX, 0, pixel);
                imRas.setPixel(dstX, dstY, pixel);
            }
        }
        return dst;
    }

    /** Stream-level metadata is not supported; always returns null. */
    @Override
    public IIOMetadata getStreamMetadata()
        throws IIOException {
        return null;
    }

    /**
     * Returns the metadata object created by {@link #readAll()}. It is the same
     * object for every valid index.
     *
     * @throws IndexOutOfBoundsException if {@code imageIndex} is out of range
     */
    @Override
    public IIOMetadata getImageMetadata(int imageIndex) throws IIOException {
        readAll();
        checkIndex(imageIndex);
        return metadata;
    }
}
