/*
 * This software is distributed under following license based on modified BSD
 * style license.
 * ----------------------------------------------------------------------
 * 
 * Copyright 2009 The Nimbus2 Project. All rights reserved.
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 * 
 * 1. Redistributions of source code must retain the above copyright notice,
 *    this list of conditions and the following disclaimer. 
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE NIMBUS PROJECT ``AS IS'' AND ANY EXPRESS
 * OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN
 * NO EVENT SHALL THE NIMBUS PROJECT OR CONTRIBUTORS BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 * 
 * The views and conclusions contained in the software and documentation are
 * those of the authors and should not be interpreted as representing official
 * policies, either expressed or implied, of the Nimbus2 Project.
 */
package jp.ossc.nimbus.service.aop.interceptor.servlet;

import java.io.*;
import java.util.zip.*;

import javax.servlet.*;
import javax.servlet.http.*;

import jp.ossc.nimbus.service.aop.*;

/**
 * HTTP response compression interceptor.<p>
 * Buffers the response body written by the downstream processing and, when the
 * request's {@code Accept-Encoding} header allows it, sends it compressed with
 * gzip or deflate together with a matching {@code Content-Encoding} header.<p>
 *
 * @author M.Takata
 */
public class HttpServletResponseDeflateInterceptorService
 extends ServletFilterInterceptorService
 implements HttpServletResponseDeflateInterceptorServiceMBean{
    
    private static final long serialVersionUID = -8811812672782874906L;
    
    /** Header name : Accept-Encoding */
    private static final String HEADER_ACCEPT_ENCODING = "Accept-Encoding";
    /** Header name : Content-Encoding */
    private static final String HEADER_CONTENT_ENCODING = "Content-Encoding";
    /** Content-Encoding : deflate */
    private static final String CONTENT_ENCODING_DEFLATE = "deflate";
    /** Content-Encoding : gzip */
    private static final String CONTENT_ENCODING_GZIP = "gzip";
    /** Content-Encoding : x-gzip (treated as gzip) */
    private static final String CONTENT_ENCODING_X_GZIP = "x-gzip";
    /** Content-Encoding : identity (no compression) */
    private static final String CONTENT_ENCODING_IDENTITY = "identity";
    /** Accept-Encoding wildcard : * (treated as gzip) */
    private static final String CONTENT_ENCODING_ALL = "*";
    /** Character encoding used when the response does not specify one. */
    private static final String DEFAULT_ENC = "ISO_8859-1";
    
    /** Content types eligible for compression; null or empty means all types. */
    private String[] enabledContentTypes;
    /** Content types that are never compressed; checked before the enabled list. */
    private String[] disabledContentTypes;
    /** Minimum body size in bytes to compress; a negative value compresses every non-empty body. */
    private int deflateLength = -1;
    
    /**
     * Sets the content types whose output may be compressed.<p>
     * When non-empty, only responses whose content type matches one of these
     * entries are compressed. Matching ignores case and accepts either the full
     * Content-Type value or its media type without parameters.
     *
     * @param contentTypes content types to compress
     */
    public void setEnabledContentTypes(String[] contentTypes){
        enabledContentTypes = contentTypes;
    }
    
    /**
     * Returns the content types whose output may be compressed.
     *
     * @return content types to compress, or null if not set
     */
    public String[] getEnabledContentTypes(){
        return enabledContentTypes;
    }
    
    /**
     * Sets the content types whose output must never be compressed.<p>
     * Matching ignores case and accepts either the full Content-Type value or
     * its media type without parameters.
     *
     * @param contentTypes content types never to compress
     */
    public void setDisabledContentTypes(String[] contentTypes){
        disabledContentTypes = contentTypes;
    }
    
    /**
     * Returns the content types whose output must never be compressed.
     *
     * @return content types never to compress, or null if not set
     */
    public String[] getDisabledContentTypes(){
        return disabledContentTypes;
    }
    
    /**
     * Sets the minimum body size, in bytes, at or above which the body is compressed.<p>
     * The default is -1, which compresses every non-empty body.
     *
     * @param length minimum size in bytes
     */
    public void setDeflateLength(int length){
        deflateLength = length;
    }
    
    /**
     * Returns the minimum body size, in bytes, at or above which the body is compressed.
     *
     * @return minimum size in bytes
     */
    public int getDeflateLength(){
        return deflateLength;
    }
    
    /**
     * Wraps the response in a compressing wrapper, then invokes the next interceptor.<p>
     * If the service has not been started, the next interceptor is invoked without doing anything.<br>
     * The response is wrapped only when the request is an HttpServletRequest,
     * the response is an HttpServletResponse, and the request carries an
     * Accept-Encoding header. When the chain returns or throws, the buffered
     * body is flushed to the client. If the wrapper is still the context's
     * response, the original response is restored on the context.
     *
     * @param context context of the invocation
     * @param chain chain used to invoke the next interceptor
     * @return return value of the invocation
     * @exception Throwable if an exception occurs in the invoked target or in this interceptor. Note that a checked exception (anything other than a RuntimeException) that the originally invoked process would not throw is not propagated to the caller.
     */
    public Object invokeFilter(
        ServletFilterInvocationContext context,
        InterceptorChain chain
    ) throws Throwable{
        if(getState() != State.STARTED){
            return chain.invokeNext(context);
        }
        final ServletRequest request = context.getServletRequest();
        final ServletResponse originalResponse = context.getServletResponse();
        boolean isWrap = false;
        // Guard both casts: a non-HTTP response must not be wrapped.
        if(request instanceof HttpServletRequest
            && originalResponse instanceof HttpServletResponse){
            final String acceptEncoding = ((HttpServletRequest)request)
                .getHeader(HEADER_ACCEPT_ENCODING);
            if(acceptEncoding != null){
                context.setServletResponse(
                    new DeflateHttpServletResponseWrapper(
                        (HttpServletResponse)originalResponse,
                        acceptEncoding,
                        enabledContentTypes,
                        disabledContentTypes,
                        deflateLength
                    )
                );
                isWrap = true;
            }
        }
        try{
            return chain.invokeNext(context);
        }finally{
            if(isWrap){
                ServletResponse response = context.getServletResponse();
                if(response instanceof DeflateHttpServletResponseWrapper){
                    final DeflateHttpServletResponseWrapper wrapper
                         = (DeflateHttpServletResponseWrapper)response;
                    try{
                        wrapper.flushBuffer();
                    }finally{
                        // Restore the original response even if flushing fails.
                        context.setServletResponse(wrapper.getResponse());
                    }
                }else{
                    // Downstream code replaced the response with its own wrapper;
                    // search the wrapper chain for ours so its buffer still gets flushed.
                    while((response instanceof ServletResponseWrapper)
                        && !(response instanceof DeflateHttpServletResponseWrapper)){
                        response = ((ServletResponseWrapper)response).getResponse();
                    }
                    if(response instanceof DeflateHttpServletResponseWrapper){
                        ((DeflateHttpServletResponseWrapper)response).flushBuffer();
                    }
                }
            }
        }
    }
    
    /**
     * Response wrapper that hands out a buffering, compressing output stream
     * for eligible content types and the plain output stream otherwise.
     */
    private static class DeflateHttpServletResponseWrapper
     extends HttpServletResponseWrapper{
        
        private final String acceptEncoding;
        private final String[] enabledContentTypes;
        private final String[] disabledContentTypes;
        private final int deflateLength;
        private ServletOutputStream sos;
        private PrintWriter pw;
        
        public DeflateHttpServletResponseWrapper(
            HttpServletResponse response,
            String acceptEncoding,
            String[] enabledContentTypes,
            String[] disabledContentTypes,
            int deflateLength
        ){
            super(response);
            this.acceptEncoding = acceptEncoding;
            this.enabledContentTypes = enabledContentTypes;
            this.disabledContentTypes = disabledContentTypes;
            this.deflateLength = deflateLength;
        }
        
        /**
         * Returns the output stream for this response.<p>
         * The choice between the plain and the compressing stream is made on
         * the first call, using the content type set at that time.
         */
        public ServletOutputStream getOutputStream() throws IOException{
            if(sos != null){
                return sos;
            }
            final String contentType = getContentType();
            if(containsContentType(disabledContentTypes, contentType)){
                sos = super.getOutputStream();
                return sos;
            }
            if(enabledContentTypes != null && enabledContentTypes.length != 0
                && !containsContentType(enabledContentTypes, contentType)){
                sos = super.getOutputStream();
                return sos;
            }
            sos = new DeflateServletOutputStreamWrapper(
                (HttpServletResponse)getResponse(),
                acceptEncoding,
                getCharacterEncoding(),
                deflateLength
            );
            return sos;
        }
        
        /**
         * Returns true if contentType matches one of types, ignoring case.
         * An entry matches either the full Content-Type value or its media type
         * with any ";" parameters (such as charset) removed.
         */
        private static boolean containsContentType(String[] types, String contentType){
            if(types == null || types.length == 0 || contentType == null){
                return false;
            }
            String mediaType = contentType;
            final int index = contentType.indexOf(';');
            if(index != -1){
                mediaType = contentType.substring(0, index).trim();
            }
            for(int i = 0; i < types.length; i++){
                if(types[i] == null){
                    continue;
                }
                final String type = types[i].trim();
                if(type.equalsIgnoreCase(contentType) || type.equalsIgnoreCase(mediaType)){
                    return true;
                }
            }
            return false;
        }
        
        public PrintWriter getWriter() throws IOException{
            if(pw == null){
                String charEncoding = getCharacterEncoding();
                pw = new PrintWriter(
                    new OutputStreamWriter(
                        getOutputStream(),
                        charEncoding == null ? DEFAULT_ENC : charEncoding
                    )
                );
            }
            return pw;
        }
        
        /**
         * Flushes buffered output, compressing it if applicable, to the wrapped response.
         */
        public void flushBuffer() throws IOException{
            // Characters written via getWriter() may still sit in the writer's
            // encoder buffer; push them to the output stream first, or they are lost.
            if(pw != null){
                pw.flush();
            }
            if(sos instanceof DeflateServletOutputStreamWrapper){
                final DeflateServletOutputStreamWrapper dsos
                     = (DeflateServletOutputStreamWrapper)sos;
                dsos.flushBuffer();
                setContentLength(dsos.getWriteLength());
            }
            super.flushBuffer();
        }
    }
    
    /**
     * Output stream that buffers everything written to it in memory and, on
     * flushBuffer(), writes it to the real response, compressed if the
     * Accept-Encoding header and the size threshold allow.<p>
     * NOTE(review): each flushBuffer() call emits only the bytes buffered since
     * the previous call, as an independent compressed segment. Content-Length
     * and Content-Encoding headers only take effect before the response is
     * committed. Intermediate flushes of a compressed response may therefore
     * produce output clients cannot decode. TODO confirm whether intermediate
     * flushes need to be supported.
     */
    private static class DeflateServletOutputStreamWrapper
     extends ServletOutputStream{
        private final HttpServletResponse response;
        private final ByteArrayOutputStream baos;
        private final PrintStream ps;
        private final String acceptEncoding;
        private final int deflateLength;
        private ServletOutputStream sos;
        /** Encoding chosen on the first non-empty flush; null until then. */
        private String appliedEncoding;
        private int writeLength;
        
        public DeflateServletOutputStreamWrapper(
            HttpServletResponse response,
            String acceptEncoding,
            String charEncoding,
            int deflateLength
        ) throws IOException{
            super();
            this.response = response;
            this.acceptEncoding = acceptEncoding;
            this.deflateLength = deflateLength;
            baos = new ByteArrayOutputStream();
            // print()/println() text is encoded with the response charset into the same buffer.
            ps = new PrintStream(
                baos,
                true,
                charEncoding == null ? DEFAULT_ENC : charEncoding
            );
        }
        
        public void write(int b) throws IOException{
            baos.write(b);
        }
        public void write(byte[] b) throws IOException{
            baos.write(b);
        }
        public void write(byte[] b, int off, int len) throws IOException{
            baos.write(b, off, len);
        }
        
        public void print(String s) throws IOException{
            ps.print(s);
        }
        public void print(boolean b) throws IOException{
            ps.print(b);
        }
        public void print(char c) throws IOException{
            ps.print(c);
        }
        public void print(int i) throws IOException{
            ps.print(i);
        }
        public void print(long l) throws IOException{
            ps.print(l);
        }
        public void print(float f) throws IOException{
            ps.print(f);
        }
        public void print(double d) throws IOException{
            ps.print(d);
        }
        public void println() throws IOException{
            ps.println();
        }
        public void println(String s) throws IOException{
            ps.println(s);
        }
        public void println(boolean b) throws IOException{
            ps.println(b);
        }
        public void println(char c) throws IOException{
            ps.println(c);
        }
        public void println(int i) throws IOException{
            ps.println(i);
        }
        public void println(long l) throws IOException{
            ps.println(l);
        }
        public void println(float f) throws IOException{
            ps.println(f);
        }
        public void println(double d) throws IOException{
            ps.println(d);
        }
        
        /**
         * Chooses the response encoding from an Accept-Encoding header value.<p>
         * Without quality values, gzip is preferred whenever gzip, x-gzip or "*"
         * appears anywhere, then deflate, otherwise identity. With quality
         * values, the supported coding with the highest q wins (first listed on
         * ties). An entry without q counts as 1.0, and q=0 is never chosen.
         *
         * @param encoding Accept-Encoding header value
         * @return gzip, deflate or identity
         */
        private static String getAppropriateEncoding(String encoding){
            if(encoding.indexOf(';') == -1){
                if(encoding.indexOf(CONTENT_ENCODING_ALL) != -1
                     || encoding.indexOf(CONTENT_ENCODING_GZIP) != -1
                     || encoding.indexOf(CONTENT_ENCODING_X_GZIP) != -1){
                    return CONTENT_ENCODING_GZIP;
                }else if(encoding.indexOf(CONTENT_ENCODING_DEFLATE) != -1){
                    return CONTENT_ENCODING_DEFLATE;
                }else{
                    return CONTENT_ENCODING_IDENTITY;
                }
            }
            double currentQValue = 0.0d;
            String result = CONTENT_ENCODING_IDENTITY;
            final String[] encodes = encoding.split(",");
            for(int i = 0; i < encodes.length; i++){
                String encode = encodes[i].trim();
                double qValue = 1.0d;
                final int index = encode.indexOf(';');
                if(index != -1){
                    qValue = parseQValue(encode.substring(index + 1));
                    encode = encode.substring(0, index).trim();
                }
                final String candidate = normalizeEncoding(encode);
                if(candidate != null && qValue > currentQValue){
                    currentQValue = qValue;
                    result = candidate;
                }
            }
            return result;
        }
        
        /**
         * Extracts the q parameter from a ";"-separated parameter list.
         * Returns 1.0 when q is absent or malformed.
         */
        private static double parseQValue(String params){
            final String[] paramArray = params.split(";");
            for(int i = 0; i < paramArray.length; i++){
                final String param = paramArray[i];
                final int index = param.indexOf('=');
                if(index != -1
                    && "q".equalsIgnoreCase(param.substring(0, index).trim())){
                    try{
                        return Double.parseDouble(param.substring(index + 1).trim());
                    }catch(NumberFormatException e){
                        // Malformed q-value: fall back to the default 1.0.
                        return 1.0d;
                    }
                }
            }
            return 1.0d;
        }
        
        /**
         * Maps a single content-coding token to the encoding this class emits,
         * or null if the token is not supported.
         */
        private static String normalizeEncoding(String encode){
            if(CONTENT_ENCODING_GZIP.equalsIgnoreCase(encode)
                || CONTENT_ENCODING_X_GZIP.equalsIgnoreCase(encode)
                || CONTENT_ENCODING_ALL.equals(encode)){
                return CONTENT_ENCODING_GZIP;
            }
            if(CONTENT_ENCODING_DEFLATE.equalsIgnoreCase(encode)){
                return CONTENT_ENCODING_DEFLATE;
            }
            if(CONTENT_ENCODING_IDENTITY.equalsIgnoreCase(encode)){
                return CONTENT_ENCODING_IDENTITY;
            }
            return null;
        }
        
        /**
         * Compresses bytes with gzip or deflate (zlib format), or returns them
         * unchanged for any other encoding.
         */
        private static byte[] encode(byte[] bytes, String encoding) throws IOException{
            final boolean gzip = CONTENT_ENCODING_GZIP.equals(encoding);
            if(!gzip && !CONTENT_ENCODING_DEFLATE.equals(encoding)){
                return bytes;
            }
            final ByteArrayOutputStream out = new ByteArrayOutputStream();
            final DeflaterOutputStream dos = gzip
                ? new GZIPOutputStream(out)
                : new DeflaterOutputStream(out);
            try{
                dos.write(bytes);
                dos.finish();
            }finally{
                dos.close();
            }
            return out.toByteArray();
        }
        
        /**
         * Writes everything buffered since the previous call to the real
         * response, compressing it when applicable.
         */
        public void flushBuffer() throws IOException{
            ps.flush();
            byte[] bytes = baos.toByteArray();
            // Always clear the buffer so bytes emitted now are never written again.
            baos.reset();
            if(bytes.length == 0){
                return;
            }
            if(appliedEncoding == null){
                // Decide once: the Content-Encoding header cannot change after the
                // response is committed, so later segments must use the same encoding.
                appliedEncoding = bytes.length >= deflateLength
                    ? getAppropriateEncoding(acceptEncoding)
                    : CONTENT_ENCODING_IDENTITY;
                if(!CONTENT_ENCODING_IDENTITY.equals(appliedEncoding)){
                    response.setHeader(
                        HEADER_CONTENT_ENCODING,
                        appliedEncoding
                    );
                }
            }
            bytes = encode(bytes, appliedEncoding);
            if(sos == null){
                sos = response.getOutputStream();
            }
            response.setContentLength(bytes.length);
            sos.write(bytes);
            sos.flush();
            writeLength += bytes.length;
        }
        
        /**
         * Returns the total number of bytes (after compression) written to the real response.
         *
         * @return bytes written so far
         */
        public int getWriteLength(){
            return writeLength;
        }
        
        /**
         * Emits any buffered data, then closes the buffers and the underlying stream.
         */
        public void close() throws IOException{
            // Flush pending data first; otherwise it would be stranded once the
            // underlying stream is closed below.
            flushBuffer();
            ps.close();
            baos.close();
            if(sos != null){
                sos.close();
            }
            super.close();
        }
    }
}