/*
 * This software is distributed under following license based on modified BSD
 * style license.
 * ----------------------------------------------------------------------
 * 
 * Copyright 2009 The Nimbus2 Project. All rights reserved.
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 * 
 * 1. Redistributions of source code must retain the above copyright notice,
 *    this list of conditions and the following disclaimer. 
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE NIMBUS PROJECT ``AS IS'' AND ANY EXPRESS
 * OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN
 * NO EVENT SHALL THE NIMBUS PROJECT OR CONTRIBUTORS BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 * 
 * The views and conclusions contained in the software and documentation are
 * those of the authors and should not be interpreted as representing official
 * policies, either expressed or implied, of the Nimbus2 Project.
 */
package jp.ossc.nimbus.service.aop.interceptor.servlet;

import java.util.*;
import java.util.regex.*;

import javax.servlet.*;
import javax.servlet.http.*;

import jp.ossc.nimbus.service.aop.*;

/**
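 * HTTP request check interceptor service.<p>
 * Validates the content length, content type, character encoding, locale,
 * protocol, remote address/host/port, scheme, server name, HTTP method and
 * headers of a request before passing it on to the next interceptor.
 * A request that fails a check is answered with the configured error status,
 * or causes an {@link HttpServletRequestCheckException} when throw-on-error is set.<p>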
 *
 * @author M.Takata
 */
public class HttpServletRequestCheckInterceptorService
 extends ServletFilterInterceptorService
 implements HttpServletRequestCheckInterceptorServiceMBean{
    
    private static final long serialVersionUID = -8791823240259229953L;
    
    // Content length bounds (both inclusive). A negative value disables the bound.
    protected int maxContentLength = -1;
    protected int minContentLength = -1;
    // Content type checks. Media types are compared ignoring case.
    protected boolean isAllowNullContentType = true;
    protected String[] validContentTypes;
    protected String[] invalidContentTypes;
    // Character encoding checks. Charset names are compared ignoring case.
    protected boolean isAllowNullCharacterEncoding = true;
    protected String[] validCharacterEncodings;
    protected String[] invalidCharacterEncodings;
    // Locale checks. validLocales are regular expressions matched against Locale#toString().
    protected boolean isAllowNullLocale = true;
    protected String[] validLocales;
    protected Pattern[] validLocalePatterns;
    // Connection checks. Remote addresses, remote hosts and server names are regular expressions.
    protected String[] validProtocols;
    protected String[] validRemoteAddrs;
    protected Pattern[] validRemoteAddrPatterns;
    protected String[] validRemoteHosts;
    protected Pattern[] validRemoteHostPatterns;
    protected int[] validRemotePorts;
    protected String[] validSchemata;
    protected String[] validServerNames;
    protected Pattern[] validServerNamePatterns;
    // HTTP-only checks. Method names are compared case-sensitively.
    protected String[] validMethods;
    protected String[] invalidMethods;
    // Header name -> regular expression that the header value must fully match.
    protected Properties headerEquals;
    protected Map<String, Pattern> headerEqualsMap;
    // Failure handling.
    protected int errorStatus = HttpServletResponse.SC_BAD_REQUEST;
    protected boolean isThrowOnError;
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public void setMaxContentLength(int max){
        maxContentLength = max;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public int getMaxContentLength(){
        return maxContentLength;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public void setMinContentLength(int min){
        minContentLength = min;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public int getMinContentLength(){
        return minContentLength;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public void setAllowNullContentType(boolean isAllow){
        isAllowNullContentType = isAllow;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public boolean isAllowNullContentType(){
        return isAllowNullContentType;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public void setValidContentTypes(String[] types){
        validContentTypes = types;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public String[] getValidContentTypes(){
        return validContentTypes;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public void setInvalidContentTypes(String[] types){
        invalidContentTypes = types;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public String[] getInvalidContentTypes(){
        return invalidContentTypes;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public void setAllowNullCharacterEncoding(boolean isAllow){
        isAllowNullCharacterEncoding = isAllow;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public boolean isAllowNullCharacterEncoding(){
        return isAllowNullCharacterEncoding;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public void setValidCharacterEncodings(String[] encodings){
        validCharacterEncodings = encodings;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public String[] getValidCharacterEncodings(){
        return validCharacterEncodings;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public void setInvalidCharacterEncodings(String[] encodings){
        invalidCharacterEncodings = encodings;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public String[] getInvalidCharacterEncodings(){
        return invalidCharacterEncodings;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public void setAllowNullLocale(boolean isAllow){
        isAllowNullLocale = isAllow;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public boolean isAllowNullLocale(){
        return isAllowNullLocale;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public void setValidLocales(String[] locales){
        validLocales = locales;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public String[] getValidLocales(){
        return validLocales;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public void setValidProtocols(String[] protocols){
        validProtocols = protocols;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public String[] getValidProtocols(){
        return validProtocols;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public void setValidRemoteAddrs(String[] addrs){
        validRemoteAddrs = addrs;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public String[] getValidRemoteAddrs(){
        return validRemoteAddrs;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public void setValidRemoteHosts(String[] hosts){
        validRemoteHosts = hosts;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public String[] getValidRemoteHosts(){
        return validRemoteHosts;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public void setValidRemotePorts(int[] ports){
        validRemotePorts = ports;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public int[] getValidRemotePorts(){
        return validRemotePorts;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public void setValidSchemata(String[] schemata){
        validSchemata = schemata;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public String[] getValidSchemata(){
        return validSchemata;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public void setValidServerNames(String[] names){
        validServerNames = names;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public String[] getValidServerNames(){
        return validServerNames;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public void setValidMethods(String[] methods){
        validMethods = methods;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public String[] getValidMethods(){
        return validMethods;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public void setInvalidMethods(String[] methods){
        invalidMethods = methods;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public String[] getInvalidMethods(){
        return invalidMethods;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public void setHeaderEquals(Properties cond){
        headerEquals = cond;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public Properties getHeaderEquals(){
        return headerEquals;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public void setErrorStatus(int status){
        errorStatus = status;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public int getErrorStatus(){
        return errorStatus;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public void setThrowOnError(boolean isThrow){
        isThrowOnError = isThrow;
    }
    
    // See HttpServletRequestCheckInterceptorServiceMBean for the Javadoc.
    public boolean isThrowOnError(){
        return isThrowOnError;
    }
    
    /**
     * Starts the service.<p>
     * Compiles the configured regular expressions (valid locales, remote
     * addresses, remote hosts, server names and header conditions) into the
     * {@link Pattern}s used by {@link #invokeFilter}. Every compiled field is
     * reassigned on each start. As a result, a setting that was cleared before
     * a restart does not leave a stale check behind.
     *
     * @exception Exception if starting the service fails, for example a
     *            {@link PatternSyntaxException} for an invalid expression
     */
    public void startService() throws Exception{
        validLocalePatterns = compilePatterns(validLocales);
        validRemoteAddrPatterns = compilePatterns(validRemoteAddrs);
        validRemoteHostPatterns = compilePatterns(validRemoteHosts);
        validServerNamePatterns = compilePatterns(validServerNames);
        
        if(headerEquals == null){
            headerEqualsMap = null;
        }else{
            final Map<String, Pattern> map = new HashMap<String, Pattern>();
            for(Map.Entry<Object, Object> entry : headerEquals.entrySet()){
                map.put(
                    (String)entry.getKey(),
                    Pattern.compile((String)entry.getValue())
                );
            }
            headerEqualsMap = map;
        }
    }
    
    /**
     * Checks the HTTP request, then invokes the next interceptor.<p>
     * If the service is not started, the next interceptor is invoked
     * without performing any check.<br>
     * If a check fails, {@link #fail} is called with a message describing
     * the first failed check. In that case the next interceptor is not invoked.
     *
     * @param context context of the invocation
     * @param chain chain used to invoke the next interceptor
     * @return the return value of the invocation, or the value returned by
     *         {@link #fail} when a check fails
     * @exception Throwable if the invoked target or this interceptor throws an
     *            exception. Note that checked exceptions the originally
     *            invoked target does not declare are not propagated to the
     *            caller even if thrown.
     */
    public Object invokeFilter(
        ServletFilterInvocationContext context,
        InterceptorChain chain
    ) throws Throwable{
        final ServletRequest request = context.getServletRequest();
        final ServletResponse response = context.getServletResponse();
        if(getState() == State.STARTED){
            final String error = checkRequest(request);
            if(error != null){
                return fail(request, response, error);
            }
        }
        return chain.invokeNext(context);
    }
    
    /**
     * Runs every configured check in a fixed order and stops at the first failure.
     *
     * @param request request to check
     * @return message describing the first failed check, or null if all pass
     */
    private String checkRequest(ServletRequest request){
        String error = checkContentLength(request);
        if(error == null){
            error = checkContentType(request);
        }
        if(error == null){
            error = checkCharacterEncoding(request);
        }
        if(error == null){
            error = checkLocale(request);
        }
        if(error == null){
            error = checkConnection(request);
        }
        if(error == null){
            error = checkHttp(request);
        }
        return error;
    }
    
    /**
     * Checks the content length against the inclusive min/max bounds.
     * getContentLength() returns -1 when the length is unknown. Such a
     * request therefore fails a configured minimum but passes a maximum.
     */
    private String checkContentLength(ServletRequest request){
        final int contentLength = request.getContentLength();
        if(maxContentLength >= 0 && contentLength > maxContentLength){
            return "MaxContentLength is " + maxContentLength
                + " : " + contentLength;
        }
        if(minContentLength >= 0 && contentLength < minContentLength){
            return "MinContentLength is " + minContentLength
                + " : " + contentLength;
        }
        return null;
    }
    
    /**
     * Checks the content type against the valid and invalid lists.
     * Media types are case-insensitive, so both lists are compared ignoring
     * case. Otherwise the invalid list could be bypassed by changing case.
     */
    private String checkContentType(ServletRequest request){
        final String contentType = request.getContentType();
        if(contentType == null){
            return isAllowNullContentType ? null : "ContentType is null.";
        }
        if(validContentTypes != null
            && !containsIgnoreCase(validContentTypes, contentType)){
            return "ContentType is invalid : " + contentType;
        }
        if(invalidContentTypes != null
            && containsIgnoreCase(invalidContentTypes, contentType)){
            return "ContentType is invalid : " + contentType;
        }
        return null;
    }
    
    /**
     * Checks the character encoding against the valid and invalid lists.
     * Charset names are case-insensitive, so both lists are compared ignoring case.
     */
    private String checkCharacterEncoding(ServletRequest request){
        final String encoding = request.getCharacterEncoding();
        if(encoding == null){
            return isAllowNullCharacterEncoding
                ? null : "CharacterEncoding is null.";
        }
        if(validCharacterEncodings != null
            && !containsIgnoreCase(validCharacterEncodings, encoding)){
            return "CharacterEncoding is invalid : " + encoding;
        }
        if(invalidCharacterEncodings != null
            && containsIgnoreCase(invalidCharacterEncodings, encoding)){
            return "CharacterEncoding is invalid : " + encoding;
        }
        return null;
    }
    
    /**
     * Passes if any of the request's locales matches any valid-locale pattern.
     * On failure, the message lists every locale that was rejected.
     */
    @SuppressWarnings("unchecked")
    private String checkLocale(ServletRequest request){
        // The cast keeps this compatible with Servlet APIs that return a raw Enumeration.
        final Enumeration<Locale> locales
             = (Enumeration<Locale>)request.getLocales();
        if(!locales.hasMoreElements()){
            return isAllowNullLocale ? null : "Locale is null.";
        }
        if(validLocales == null){
            return null;
        }
        final List<String> rejected = new ArrayList<String>();
        while(locales.hasMoreElements()){
            final String locale = locales.nextElement().toString();
            if(matchesAny(validLocalePatterns, locale)){
                return null;
            }
            rejected.add(locale);
        }
        return "Locale is invalid : " + rejected;
    }
    
    /** Checks protocol, remote address/host/port, scheme and server name. */
    private String checkConnection(ServletRequest request){
        if(validProtocols != null){
            final String protocol = request.getProtocol();
            if(!contains(validProtocols, protocol)){
                return "Protocol is invalid : " + protocol;
            }
        }
        if(validRemoteAddrs != null){
            final String addr = request.getRemoteAddr();
            if(!matchesAny(validRemoteAddrPatterns, addr)){
                return "Remote address is invalid : " + addr;
            }
        }
        if(validRemoteHosts != null){
            final String host = request.getRemoteHost();
            if(!matchesAny(validRemoteHostPatterns, host)){
                return "Remote host is invalid : " + host;
            }
        }
        if(validRemotePorts != null){
            final int port = request.getRemotePort();
            if(!contains(validRemotePorts, port)){
                return "Remote port is invalid : " + port;
            }
        }
        if(validSchemata != null){
            final String scheme = request.getScheme();
            if(!contains(validSchemata, scheme)){
                return "Scheme is invalid : " + scheme;
            }
        }
        if(validServerNames != null){
            final String serverName = request.getServerName();
            if(!matchesAny(validServerNamePatterns, serverName)){
                return "Server name is invalid : " + serverName;
            }
        }
        return null;
    }
    
    /** Checks the HTTP method and required headers. Non-HTTP requests pass. */
    private String checkHttp(ServletRequest request){
        if(!(request instanceof HttpServletRequest)){
            return null;
        }
        final HttpServletRequest httpReq = (HttpServletRequest)request;
        
        final String method = httpReq.getMethod();
        if(method != null){
            if(validMethods != null && !contains(validMethods, method)){
                return "Method is invalid : " + method;
            }
            if(invalidMethods != null && contains(invalidMethods, method)){
                return "Method is invalid : " + method;
            }
        }
        
        if(headerEqualsMap != null){
            for(Map.Entry<String, Pattern> entry : headerEqualsMap.entrySet()){
                final String name = entry.getKey();
                final String value = httpReq.getHeader(name);
                // A missing header fails just like a non-matching one.
                if(value == null || !entry.getValue().matcher(value).matches()){
                    return "Header " + name + " is invalid : " + value;
                }
            }
        }
        return null;
    }
    
    /** Compiles each expression, or returns null for a null or empty array. */
    private static Pattern[] compilePatterns(String[] regexes){
        if(regexes == null || regexes.length == 0){
            return null;
        }
        final Pattern[] patterns = new Pattern[regexes.length];
        for(int i = 0; i < regexes.length; i++){
            patterns[i] = Pattern.compile(regexes[i]);
        }
        return patterns;
    }
    
    /** True if value fully matches any pattern. A null value or pattern array never matches. */
    private static boolean matchesAny(Pattern[] patterns, String value){
        if(patterns == null || value == null){
            return false;
        }
        for(Pattern pattern : patterns){
            if(pattern.matcher(value).matches()){
                return true;
            }
        }
        return false;
    }
    
    /** True if value equals (case-sensitively) any element. A null value never matches. */
    private static boolean contains(String[] values, String value){
        if(value == null){
            return false;
        }
        for(String candidate : values){
            if(value.equals(candidate)){
                return true;
            }
        }
        return false;
    }
    
    /** True if value equals any element ignoring case. A null value never matches. */
    private static boolean containsIgnoreCase(String[] values, String value){
        if(value == null){
            return false;
        }
        for(String candidate : values){
            if(value.equalsIgnoreCase(candidate)){
                return true;
            }
        }
        return false;
    }
    
    /** True if value equals any element. */
    private static boolean contains(int[] values, int value){
        for(int candidate : values){
            if(candidate == value){
                return true;
            }
        }
        return false;
    }
    
    /**
     * Handles a failed check.<p>
     * If throw-on-error is enabled, throws an exception carrying the message.
     * Otherwise, sets the configured error status on the response when it is an
     * {@link HttpServletResponse}, then returns null.
     *
     * @param request the rejected request
     * @param response the response of the rejected request
     * @param message description of the failed check
     * @return null when no exception is thrown
     * @exception HttpServletRequestCheckException if throw-on-error is enabled
     */
    protected Object fail(
        ServletRequest request,
        ServletResponse response,
        String message
    ) throws HttpServletRequestCheckException{
        if(isThrowOnError){
            throw new HttpServletRequestCheckException(message);
        }
        if(response instanceof HttpServletResponse){
            ((HttpServletResponse)response).setStatus(errorStatus);
        }
        return null;
    }
}
