/******************************************************************************
 * Copyright (C) 2006 Tetsuya Kimata <kimata@acapulco.dyndns.org>
 *
 * All rights reserved.
 *
 * This software is provided 'as-is', without any express or implied
 * warranty.  In no event will the authors be held liable for any
 * damages arising from the use of this software.
 *
 * Permission is granted to anyone to use this software for any
 * purpose, including commercial applications, and to alter it and
 * redistribute it freely, subject to the following restrictions:
 *
 * 1. The origin of this software must not be misrepresented; you must
 *    not claim that you wrote the original software. If you use this
 *    software in a product, an acknowledgment in the product
 *    documentation would be appreciated but is not required.
 *
 * 2. Altered source versions must be plainly marked as such, and must
 *    not be misrepresented as being the original software.
 *
 * 3. This notice may not be removed or altered from any source
 *    distribution.
 *
 * $Id: RFC1867Parser.cpp 1822 2006-10-18 17:41:28Z svn $
 *****************************************************************************/

#ifndef TEMPLATE_INSTANTIATION
#include "Environment.h"
#endif

#include "RFC1867Parser.h"
#include "TemporaryFile.h"
#include "DirectoryCleaner.h"
#include "Auxiliary.h"
#include "Message.h"
#include "Macro.h"
#include "SourceInfo.h"

#ifndef TEMPLATE_INSTANTIATION
SOURCE_INFO_ADD("$Id: RFC1867Parser.cpp 1822 2006-10-18 17:41:28Z svn $");
#endif

// Reinterprets an untyped pointer (for example an APR array slot or the
// array's element storage) as a pointer to content_t.
#define AS_CONTENT(pointer)         reinterpret_cast<content_t *>(pointer)

// Throws `message`; in debug builds the current input buffer is dumped first.
// NOTE: the debug form expands to two statements, so this macro must only be
// used inside a braced block (never as the sole body of an unbraced `if`).
#ifdef DEBUG_RFC1867Parser
#define DUMP_INPUT_AND_THROW(message) dump_input(buffer_); THROW(message)
#else
#define DUMP_INPUT_AND_THROW(message) THROW(message)
#endif

// Debug-only tracing support.
//  - parser_is_trace : when true, TRACE_FUNC prints a "CALL: ..." line.
//  - TRACE_FUNC      : prints the line number (MSVC) or the function name
//                      (other compilers); expands to nothing in release
//                      builds.  It expands to a bare `if`, so it must be used
//                      as a stand-alone statement.
//  - DUMP_TOKEN_AND_THROW : same two-statement caveat as DUMP_INPUT_AND_THROW.
#ifdef DEBUG_RFC1867Parser
static bool parser_is_trace = false;
#define DUMP_TOKEN_AND_THROW(message) dump_read_token(); THROW(message)
#ifdef _MSC_VER
#define TRACE_FUNC if (parser_is_trace) cerr << "CALL: " << __LINE__ << endl
#else
#define TRACE_FUNC if (parser_is_trace) cerr << "CALL: " << __func__ << endl
#endif
#else
#define DUMP_TOKEN_AND_THROW(message) THROW(message)
#define TRACE_FUNC
#endif

// Out-of-class definitions of the parser's static constants.
//  - READ_BLOCK_SIZE / READ_TIMEOUT_SEC take their values from the
//    REQ_READ_BLOCK_SIZE / REQ_READ_TIMEOUT_SEC macros defined elsewhere.
//  - The string constants are the literal tokens matched while parsing the
//    Content-Type header and each part's headers.  Note that
//    MULTIPART_FORM_DATA, CONTENT_TYPE, CONTENT_DISPOSITION and FORM_DATA
//    include their trailing space, so input must match that spacing exactly.
//  - FILE_NAME_TEMPLATE is the file-name stem appended to the upload
//    directory when creating temporary files.
template<class R, class W> const apr_size_t
RFC1867Parser<R, W>::READ_BLOCK_SIZE        = REQ_READ_BLOCK_SIZE;
template<class R, class W> const apr_size_t
RFC1867Parser<R, W>::READ_TIMEOUT_SEC       = REQ_READ_TIMEOUT_SEC;
template<class R, class W> const char
RFC1867Parser<R, W>::CR_LF[]                = "\r\n";
template<class R, class W> const char
RFC1867Parser<R, W>::MULTIPART_FORM_DATA[]  = "multipart/form-data; ";
template<class R, class W> const char
RFC1867Parser<R, W>::CONTENT_TYPE[]         = "Content-Type: ";
template<class R, class W> const char
RFC1867Parser<R, W>::CONTENT_DISPOSITION[]  = "Content-Disposition: ";
template<class R, class W> const char
RFC1867Parser<R, W>::FORM_DATA[]            = "form-data; ";
template<class R, class W> const char
RFC1867Parser<R, W>::BOUNDARY_PARAM[]       = "boundary";
template<class R, class W> const char
RFC1867Parser<R, W>::BOUNDARY_PREFIX[]      = "--";
template<class R, class W> const char
RFC1867Parser<R, W>::ASSIGN                 = '=';
template<class R, class W> const char
RFC1867Parser<R, W>::QUOTE                  = '"';
template<class R, class W> const char
RFC1867Parser<R, W>::DELIMITER              = ';';
template<class R, class W> const char
RFC1867Parser<R, W>::NAME_PARAM[]           = "name";
template<class R, class W> const char
RFC1867Parser<R, W>::FILENAME_PARAM[]       = "filename";
template<class R, class W> const char
RFC1867Parser<R, W>::FILE_NAME_TEMPLATE[]   = "post";

/******************************************************************************
 * public メソッド
 *****************************************************************************/
// Stores the parser configuration; no input is read until parse() is called.
//
//   pool          : APR pool used for every allocation made while parsing
//   reader        : source of the raw POST body
//   file_dir_path : directory in which uploaded files are stored
//   max_text_size : limit applied to text field values
//   max_file_size : limit applied to uploaded files
//   max_item_num  : maximum number of parts accepted
//   file_offset   : passed through to the file writer
template<class R, class W>
RFC1867Parser<R, W>::RFC1867Parser(apr_pool_t *pool, PostReaderClass& reader,
                                   const char *file_dir_path,
                                   apr_size_t max_text_size,
                                   apr_uint64_t max_file_size,
                                   apr_size_t max_item_num,
                                   apr_size_t file_offset)
    : pool_(pool)
    , reader_(reader)
    , buffer_(READ_BLOCK_SIZE * 2) // must be at least two read blocks
    , boundary_(NULL)
    , boundary_len_(0)
    , barrier_len_(0)
    , file_dir_path_(file_dir_path)
    , max_text_size_(max_text_size)
    , max_file_size_(max_file_size)
    , max_item_num_(max_item_num)
    , file_offset_(file_offset)
{
}

// Parses a complete multipart/form-data request body.
//
//   content_type : value of the request's Content-Type header; the boundary
//                  token is extracted from it
//   content_size : declared size of the request body, used only for a coarse
//                  up-front limit check
//
// Returns an APR array of content_t, one element per parsed part.  Throws
// (via THROW) when the body is empty, malformed, too large, or contains more
// than max_item_num_ parts.
template<class R, class W>
apr_array_header_t *RFC1867Parser<R, W>::parse(const char *content_type,
                                               apr_size_t content_size)
{
    content_t content;

    TRACE_FUNC;

    // Remove stale temporary files left in the upload directory.
    DirectoryCleaner::clean_old_files(pool_, file_dir_path_, READ_TIMEOUT_SEC);

    // Provisional size check: reject bodies that could not possibly fit even
    // if every part used its full allowance.
    if (content_size > ((max_text_size_+max_file_size_) * max_item_num_)) {
        THROW(MESSAGE_RFC1867_DATA_SIZE_TOO_LARGE);
    }

    apr_array_header_t *items =
        apr_array_make(pool_,
                       static_cast<int>(max_item_num_),
                       static_cast<int>(sizeof(content_t)));

    boundary_ = get_boundary(content_type);
    boundary_len_ = strlen(boundary_);
    // A full delimiter in the body is CRLF + "--" + boundary.
    barrier_len_ = strlen(CR_LF) + strlen(BOUNDARY_PREFIX) + boundary_len_;

    if (fill() == 0) {
        THROW(MESSAGE_RFC1867_CONTENT_SIZE_ZERO);
    }

    // Drop the opening boundary line but keep its CRLF, because every part
    // is expected to begin with CRLF.
    const char *first_line_break =
        skip_line(buffer_.get_data()) - strlen(CR_LF);
    buffer_.erase(first_line_break - buffer_.get_data());

    for (;;) {
        if (is_end()) {
            break;
        }
        if (static_cast<apr_size_t>(items->nelts) == max_item_num_) {
            THROW(MESSAGE_RFC1867_ITEM_COUNT_EXCEEDED);
        }

        get_content(&content);
        *AS_CONTENT(apr_array_push(items)) = content;
    }

    return items;
}

// Looks up a parsed part by its form-field name.
//
//   content_array : array of content_t as returned by parse()
//   name          : field name to search for (exact, case-sensitive match)
//
// Returns a pointer to the first element whose name equals `name`, or NULL
// when no such element exists.
template<class R, class W>
typename RFC1867Parser<R, W>::content_t *
RFC1867Parser<R, W>::get_content(apr_array_header_t *content_array,
                                 const char *name)
{
    content_t *contents = AS_CONTENT(content_array->elts);

    for (int i = 0; i < content_array->nelts; i++) {
        // Compare the whole string: a length-limited comparison would treat
        // `name` as a prefix and return e.g. "filename" for "file".
        if (strcmp(contents[i].name, name) == 0) {
            return contents + i;
        }
    }

    return NULL;
}

// Prints every element of a parsed content array to standard output, with a
// separator line before the first element and after each element.
template<class R, class W>
void RFC1867Parser<R, W>::dump_content_array(apr_array_header_t *content_array)
{
    static const char SEPARATOR[] = "****************************************";

    cout << SEPARATOR << endl;

    content_t *item = AS_CONTENT(content_array->elts);
    for (int remaining = content_array->nelts; remaining > 0; remaining--) {
        dump_content(item);
        cout << SEPARATOR << endl;
        item++;
    }
}


/******************************************************************************
 * private メソッド
 *****************************************************************************/
// Parses the next part from the input buffer into *content.
//
// The buffer must be positioned on the CRLF that follows a boundary.  The
// Content-Disposition line is parsed for the mandatory `name` parameter and
// the optional `filename` parameter; the presence of `filename` decides
// whether the body is read as an uploaded file or as a text value.
// Throws MESSAGE_RFC1867_FORMAT_INVALID on malformed headers.
template<class R, class W>
void RFC1867Parser<R, W>::get_content(content_t *content)
{
    TRACE_FUNC;

    // Nothing could be read and nothing is buffered: the body ended early.
    const bool is_exhausted = (fill() == 0) && (buffer_.get_size() == 0);
    if (is_exhausted) {
        DUMP_INPUT_AND_THROW(MESSAGE_RFC1867_FORMAT_INVALID);
    }

    if (!start_with(buffer_.get_data(), CR_LF)) {
        DUMP_INPUT_AND_THROW(MESSAGE_RFC1867_FORMAT_INVALID);
    }

    // Content-Disposition line: "Content-Disposition: form-data; name=..."
    const char *cursor = buffer_.get_data() + strlen(CR_LF);
    const char *header_end = skip_line(cursor);

    cursor = skip(skip(cursor, CONTENT_DISPOSITION), FORM_DATA);

    cursor = get_param(cursor, header_end, NAME_PARAM, &(content->name));
    if (cursor == NULL) {
        DUMP_INPUT_AND_THROW(MESSAGE_RFC1867_FORMAT_INVALID);
    }

    cursor = get_param(cursor, header_end, FILENAME_PARAM,
                       &(content->file.name));

    if (cursor != NULL) { // file upload
        content->file.name = basename_ex(content->file.name);

        // The line after Content-Disposition carries the MIME type.
        const char *type_line = skip_line(cursor);
        const char *type_line_end = skip_line(type_line);
        const char *mime_start = skip(type_line, CONTENT_TYPE);

        content->file.mime =
            AS_CONST_CHAR(apr_pstrmemdup(pool_, mime_start,
                                         type_line_end - mime_start -
                                         strlen(CR_LF)));

        // Skip the empty line that separates the headers from the data.
        const char *body_start = skip_line(type_line_end);
        buffer_.erase(body_start - buffer_.get_data());

        get_file_content(content);

        return;
    }

    // Plain text field: skip the empty line that separates the headers from
    // the value.
    const char *body_start = skip_line(header_end);
    buffer_.erase(body_start - buffer_.get_data());

    get_text_content(content);
}

// Reads the value of a text form field into content->text.
//
// The buffer must be positioned on the first byte of the value.  The value
// ends just before the CRLF + "--" that precedes the next boundary token; on
// return the buffer has been consumed up to and including that token.
//
// Throws MESSAGE_RFC1867_TEXT_SIZE_TOO_LARGE when the value exceeds
// max_text_size_, and MESSAGE_RFC1867_FORMAT_INVALID when the input ends
// without a boundary or the boundary is not preceded by CRLF + "--".
template<class R, class W>
void RFC1867Parser<R, W>::get_text_content(content_t *content)
{
    char *text;
    apr_size_t text_length;
    const char *end;
    apr_size_t read_size;
    apr_size_t tail_size;
    // Number of bytes (CRLF + "--") between the value and the boundary token.
    const apr_size_t prefix_len = strlen(CR_LF) + strlen(BOUNDARY_PREFIX);

    TRACE_FUNC;

    content->type = TEXT;

    text = NULL;
    text_length = 0;

    end = AS_CONST_CHAR(memmem(buffer_.get_data(), buffer_.get_size(),
                               boundary_, boundary_len_));

    if (end != NULL) {
        // Fast path: the whole value is already in the buffer.
        const apr_size_t boundary_offset =
            static_cast<apr_size_t>(end - buffer_.get_data());

        // Without this check a boundary closer than CRLF + "--" to the start
        // of the buffer would make the unsigned length wrap around.
        if (boundary_offset < prefix_len) {
            DUMP_INPUT_AND_THROW(MESSAGE_RFC1867_FORMAT_INVALID);
        }
        if ((boundary_offset - prefix_len) > max_text_size_) {
            DUMP_INPUT_AND_THROW(MESSAGE_RFC1867_TEXT_SIZE_TOO_LARGE);
        }

        content->text = apr_pstrmemdup(pool_,
                                       buffer_.get_data(),
                                       boundary_offset - prefix_len);
        end += boundary_len_;
        buffer_.erase(end - buffer_.get_data());

        return;
    }

    // Slow path: accumulate the value in a heap buffer while reading more
    // input.  `text` always refers to that heap buffer so that the handler
    // below can release it safely.
    try {
        MALLOC(text, char *, sizeof(char), buffer_.get_size()); // over-allocate
        write_text(&text, &text_length, buffer_, barrier_len_);

        // This loop is expected to finish after a few iterations.
        while (UNLIKELY(true)) {
            if (text_length > max_text_size_) {
                DUMP_INPUT_AND_THROW(MESSAGE_RFC1867_TEXT_SIZE_TOO_LARGE);
            }

            read_size = read();

            end = AS_CONST_CHAR(memmem(buffer_.get_data(), buffer_.get_size(),
                                       boundary_, boundary_len_));
            if (end != NULL) { // boundary found in the buffer
                break;
            } else if (read_size == 0) { // input ended without a boundary
                DUMP_INPUT_AND_THROW(MESSAGE_RFC1867_FORMAT_INVALID);
            }

            REALLOC(text, char *, sizeof(char),
                    text_length + buffer_.get_size());
            write_text(&text, &text_length, buffer_, barrier_len_);
        }

        const apr_size_t boundary_offset =
            static_cast<apr_size_t>(end - buffer_.get_data());

        // Same wrap-around guard as on the fast path.
        if (boundary_offset < prefix_len) {
            DUMP_INPUT_AND_THROW(MESSAGE_RFC1867_FORMAT_INVALID);
        }
        tail_size = boundary_offset - prefix_len;

        // Enforce the limit on the final length as well.
        if ((text_length + tail_size) > max_text_size_) {
            DUMP_INPUT_AND_THROW(MESSAGE_RFC1867_TEXT_SIZE_TOO_LARGE);
        }

        // Move the accumulated value into pool memory and terminate it.
        char *pool_text;
        APR_PALLOC(pool_text, char *, pool_, text_length + tail_size + 1);

        memcpy(pool_text, text, text_length);
        memcpy(pool_text + text_length, buffer_.get_data(), tail_size);
        *(pool_text + text_length + tail_size) = '\0';

        end += boundary_len_;
        buffer_.erase(end - buffer_.get_data());

        free(text);
        text = NULL;

        content->text = pool_text;
    } catch(const char *) {
        if (text != NULL) {
            free(text);
        }
        throw;
    }
}

// Streams the body of an uploaded file into a temporary file and records
// its size, digest and temporary path in *content.
//
// The buffer must be positioned on the first byte of the file data.  The
// data ends just before the CRLF + "--" that precedes the next boundary
// token; on return the buffer has been consumed up to and including that
// token.
//
// Throws MESSAGE_RFC1867_FILE_SIZE_TOO_LARGE when the file exceeds
// max_file_size_, and MESSAGE_RFC1867_FORMAT_INVALID when the input ends
// without a boundary or the boundary is not preceded by CRLF + "--".
template<class R, class W>
void RFC1867Parser<R, W>::get_file_content(content_t *content)
{
    const char *end;
    apr_size_t read_size;
    apr_size_t write_size;
    // Number of bytes (CRLF + "--") between the data and the boundary token.
    const apr_size_t prefix_len = strlen(CR_LF) + strlen(BOUNDARY_PREFIX);

    TRACE_FUNC;

    content->type = FILE;

    TemporaryFile temp_file(pool_,
                            apr_pstrcat(pool_,
                                        file_dir_path_, "/", FILE_NAME_TEMPLATE,
                                        NULL),
                            false);
    temp_file.open(FileWriterClass::OPEN_FLAG);

    FileWriterClass writer(pool_, temp_file.get_handle(), file_offset_);
    MessageDigest5 digest;

    end = AS_CONST_CHAR(memmem(buffer_.get_data(), buffer_.get_size(),
                               boundary_, boundary_len_));

    if (end == NULL) { // no boundary in the buffer yet
        while (true) {
            if (UNLIKELY(writer.get_write_size() > max_file_size_)) {
                DUMP_INPUT_AND_THROW(MESSAGE_RFC1867_FILE_SIZE_TOO_LARGE);
            }

            // The tail of the buffer may hold the beginning of a boundary,
            // so keep the last barrier_len_ bytes back when writing.
            if (buffer_.get_size() < barrier_len_) {
                DUMP_INPUT_AND_THROW(MESSAGE_RFC1867_FORMAT_INVALID);
            }

            write_size = buffer_.get_size() - barrier_len_;
            write_file(writer, digest, buffer_, write_size);

            read_size = fill();
            end = AS_CONST_CHAR(memmem(buffer_.get_data(), buffer_.get_size(),
                                       boundary_, boundary_len_));
            if (UNLIKELY(end != NULL)) { // boundary found in the buffer
                break;
            } else if (UNLIKELY(read_size == 0)) { // input ended, no boundary
                DUMP_INPUT_AND_THROW(MESSAGE_RFC1867_FORMAT_INVALID);
            }
        }
    }

    // Without this check a boundary closer than CRLF + "--" to the start of
    // the buffer would make the unsigned write size wrap around.
    if (static_cast<apr_size_t>(end - buffer_.get_data()) < prefix_len) {
        DUMP_INPUT_AND_THROW(MESSAGE_RFC1867_FORMAT_INVALID);
    }

    write_size = end - buffer_.get_data() - prefix_len;

    write_file(writer, digest, buffer_, write_size);

    // Enforce the limit on the final size too; the loop above only checks
    // it before each intermediate write.
    if (writer.get_write_size() > max_file_size_) {
        DUMP_INPUT_AND_THROW(MESSAGE_RFC1867_FILE_SIZE_TOO_LARGE);
    }

    // write_file() erased write_size bytes from the front of the buffer, so
    // shift the boundary position accordingly before stepping over it.
    end = end - write_size + boundary_len_;
    buffer_.erase(end - buffer_.get_data());

    digest.finish();

    content->file.size = writer.get_write_size();
    content->file.digest = apr_pstrdup(pool_, digest.c_str());
    content->file.temp_path = temp_file.get_temp_path();
}

// Reports whether the parser has reached the closing delimiter.
//
// Called right after a boundary token has been consumed: a following "--"
// marks the end of the multipart body.  Tops up the buffer first.
template<class R, class W>
bool RFC1867Parser<R, W>::is_end()
{
    TRACE_FUNC;

    fill();

    const bool has_closing_marker =
        start_with(buffer_.get_data(), BOUNDARY_PREFIX);

    return has_closing_marker;
}

// Reads from the input until the buffer holds at least READ_BLOCK_SIZE bytes
// or the input is exhausted.
//
// Returns the number of bytes added by this call (0 when the buffer was
// already full enough or no more input is available).
template<class R, class W>
apr_size_t RFC1867Parser<R, W>::fill()
{
    apr_size_t total_read = 0;

    TRACE_FUNC;

    while (buffer_.get_size() < READ_BLOCK_SIZE) {
        const apr_size_t chunk = read();

        if (chunk == 0) {
            return total_read;
        }
        total_read += chunk;
    }

    return total_read;
}

// Reads up to `size` bytes from the request reader and appends them to the
// input buffer.
//
// Returns the number of bytes actually appended.
template<class R, class W>
apr_size_t RFC1867Parser<R, W>::read(apr_size_t size)
{
    TRACE_FUNC;

    apr_size_t appended = 0;
    reader_.read(buffer_.get_data_end(), size, &appended);

    buffer_.add_size(appended);
    return appended;
}

// Extracts the boundary token from a Content-Type header value of the form
// "multipart/form-data; boundary=...".
//
// Returns the boundary string (allocated from the pool by get_param).
// Throws MESSAGE_RFC1867_CONTENT_TYPE_INVALID when the parameter is missing,
// or when the token is not longer than "--" or is longer than half a read
// block.
template<class R, class W>
const char *RFC1867Parser<R, W>::get_boundary(const char *content_type)
{
    const char *boundary;

    TRACE_FUNC;

    const char *params = skip(content_type, MULTIPART_FORM_DATA);
    const char *params_end = content_type + strlen(content_type);

    if (get_param(params, params_end, BOUNDARY_PARAM, &boundary) == NULL) {
        THROW(MESSAGE_RFC1867_CONTENT_TYPE_INVALID);
    }

    const apr_size_t length = strlen(boundary);
    const bool is_too_short = (length <= strlen(BOUNDARY_PREFIX));

    if (is_too_short) {
        THROW(MESSAGE_RFC1867_CONTENT_TYPE_INVALID);
    }
    if (length > (READ_BLOCK_SIZE/2)) {
        THROW(MESSAGE_RFC1867_CONTENT_TYPE_INVALID);
    }

    return boundary;
}

// Parses one `name=value` / `name="value"` parameter from a header line.
//
//   input_start : position at which to start looking; leading whitespace
//                 and ';' are skipped
//   input_end   : one past the last byte that belongs to the header line
//   name        : parameter name expected at the current position
//   value       : receives a pool-allocated copy of the value, or NULL when
//                 the parameter is not present
//
// Returns the position just after the value, or NULL when the text at the
// current position is not `name=`.  Throws MESSAGE_RFC1867_FORMAT_INVALID
// when a quoted value is not terminated before input_end.
template<class R, class W>
const char *RFC1867Parser<R, W>::get_param(const char *input_start,
                                           const char *input_end,
                                           const char *name,
                                           const char **value)
{
    const char *end;
    const apr_size_t name_length = strlen(name);

    TRACE_FUNC;

    // Stay inside the header line.  Without the bound, CR and LF count as
    // whitespace and the scan would run on into the part body, where a text
    // value beginning with `filename="..."` would be mistaken for a
    // parameter.
    while ((input_start < input_end) &&
           (isspace(*input_start & 0xff) || (*input_start == DELIMITER))) {
        input_start++;
    }

    // The name and the '=' that follows it must both fit before input_end.
    if ((input_start >= input_end) ||
        (static_cast<apr_size_t>(input_end - input_start) <= name_length) ||
        !start_with(input_start, name)) {
        *value = NULL;

        return NULL;
    }

    input_start += name_length;
    if (*(input_start++) != ASSIGN) {
        *value = NULL;

        return NULL;
    }

    if ((input_start < input_end) && (*input_start == QUOTE)) {
        // Quoted value: everything up to the closing quote.
        input_start++;
        end = strnchr(input_start, input_end-input_start, QUOTE);
        if (end == NULL) {
            DUMP_INPUT_AND_THROW(MESSAGE_RFC1867_FORMAT_INVALID);
        }

        *value = static_cast<const char *>(apr_pstrmemdup(pool_, input_start,
                                                          end-input_start));
        end++;
    } else {
        // Unquoted value: the first byte is always taken, then everything up
        // to the next whitespace.  An empty remainder yields an empty value
        // instead of reading past input_end.
        // NOTE(review): an unquoted value is not terminated by ';', so any
        // further parameters on the line become part of it; confirm whether
        // that is intended.
        end = input_start;
        if (end < input_end) {
            end++;
        }
        while ((end < input_end) && !isspace(*end & 0xff)) {
            end++;
        }
        *value = static_cast<const char *>(apr_pstrmemdup(pool_, input_start,
                                                          end-input_start));
    }

    return end;
}

// Advances past a fixed token.
//
//   input_start : current position
//   pattern     : token expected at that position
//   is_must     : when true, the token is verified and
//                 MESSAGE_RFC1867_FORMAT_INVALID is thrown if it is absent;
//                 when false, the position is advanced without checking
//
// Returns input_start advanced by the length of `pattern`.
template<class R, class W>
const char *RFC1867Parser<R, W>::skip(const char *input_start,
                                      const char *pattern, bool is_must)
{
    TRACE_FUNC;

    if (is_must) {
        if (!start_with(input_start, pattern)) {
            DUMP_INPUT_AND_THROW(MESSAGE_RFC1867_FORMAT_INVALID);
        }
    }

    const char *after_pattern = input_start + strlen(pattern);

    return after_pattern;
}

// Appends the buffered bytes to *text, holding back the last `barrier_len`
// bytes because they may be the beginning of a boundary.
//
//   text        : destination; the caller must have made it large enough
//   text_length : number of bytes already stored in *text; updated here
//   buffer      : input buffer; the copied bytes are erased from it
//   barrier_len : number of trailing bytes to leave in the buffer
//
// Does nothing when the buffer holds fewer than `barrier_len` bytes.
template<class R, class W>
void RFC1867Parser<R, W>::write_text(char **text, apr_size_t *text_length,
                                     RFC1867ParserBuffer& buffer,
                                     apr_size_t barrier_len)
{
    TRACE_FUNC;

    const apr_size_t buffered = buffer.get_size();
    if (buffered < barrier_len) {
        return;
    }

    const apr_size_t copy_size = buffered - barrier_len;
    char *destination = *text + *text_length;

    memcpy(destination, buffer.get_data(), copy_size);
    *text_length += copy_size;
    buffer.erase(copy_size);
}

// Writes the first `size` bytes of the buffer to the file, feeds the same
// bytes to the digest, and then removes them from the buffer.
//
//   writer : destination file writer
//   digest : running digest of the file data
//   buffer : input buffer; `size` bytes are erased from its front
//   size   : number of bytes to consume
template<class R, class W>
void RFC1867Parser<R, W>::write_file(FileWriterClass& writer,
                                     MessageDigest5& digest,
                                     RFC1867ParserBuffer& buffer,
                                     apr_size_t size)
{
    TRACE_FUNC;

    writer.write(buffer.get_data(), size);
    digest.update(AS_BYTE(buffer.get_data()), size);
    // Erase only after both consumers have seen the data.
    buffer.erase(size);
}

// Returns the position just after the next CRLF at or after input_start.
//
// Throws MESSAGE_RFC1867_FORMAT_INVALID when no CRLF is found.
// NOTE(review): the search uses strstr, so it relies on the data being
// NUL-terminated somewhere; confirm that the buffer guarantees this.
template<class R, class W>
const char *RFC1867Parser<R, W>::skip_line(const char *input_start)
{
    const char *line_break = strstr(input_start, CR_LF);

    TRACE_FUNC;

    if (line_break != NULL) {
        return line_break + strlen(CR_LF);
    }

    DUMP_INPUT_AND_THROW(MESSAGE_RFC1867_FORMAT_INVALID);

    return line_break + strlen(CR_LF);
}

// Returns true when the first `pattern_length` characters of `str` and
// `pattern` compare equal (strncmp semantics).
template<class R, class W>
bool RFC1867Parser<R, W>::start_with(const char *str, const char *pattern,
                                     apr_size_t pattern_length)
{
    TRACE_FUNC;

    const int difference = strncmp(str, pattern, pattern_length);

    return difference == 0;
}

// Prints one parsed part to standard output: its name, its type, and the
// type-specific fields (the text value, or the file attributes).
template<class R, class W>
void RFC1867Parser<R, W>::dump_content(content_t *content)
{
    cout << "name           : " << content->name << endl;

    if (content->type == TEXT) {
        cout << "type           : TEXT" << endl;
        cout << "value          : " << content->text << endl;
        return;
    }

    if (content->type != FILE) {
        cout << "type           : UNKNOWN" << endl;
        return;
    }

    cout << "type           : FILE" << endl;
    cout << "file.name      : " << content->file.name << endl;
    cout << "file.temp_path : " << content->file.temp_path << endl;
    cout << "file.size      : " << content->file.size << endl;
    cout << "file.mime      : " << content->file.mime << endl;
    cout << "file.digest    : " << content->file.digest << endl;
}

// Prints the current contents of the input buffer to standard error
// (debugging aid, used just before a parse error is thrown).
//
// The bytes are written with an explicit length, so the buffer is neither
// modified nor required to be NUL-terminated, and an empty buffer is safe.
template<class R, class W>
void RFC1867Parser<R, W>::dump_input(RFC1867ParserBuffer& buffer)
{
    cerr << "INPUT:" << endl;
    cerr.write(buffer.get_data(),
               static_cast<std::streamsize>(buffer.get_size()));
    cerr << endl;
}


/******************************************************************************
 * テスト
 *****************************************************************************/
#ifdef DEBUG_RFC1867Parser
#include "TestRunner.h"

#include "File.h"
#include "CGIRequestReader.h"
#include "BasicFileWriter.h"

#include <fstream>

// Settings used by the stand-alone test driver.
// Name of the form field whose digest is verified.
static const char FILE_CONTENT_NAME[]   = "file";
// Limits passed to the parser under test.
static const apr_size_t MAX_TEXT_SIZE   = 100 * 1024;
static const apr_size_t MAX_FILE_SIZE   = 100 * 1024 * 1024;
static const apr_size_t MAX_ITEM_NUM    = 10;

// Parser instantiation exercised by the test driver.
typedef RFC1867Parser<CGIRequestReader, BasicFileWriter> RFC1867ParserImpl;

// Prints the command-line synopsis of the test driver to standard error.
void show_usage(const char *prog_name)
{
    cerr << "Usage: " << prog_name;
    cerr << " <POST_FILE_PATH> <CONTENT_TYPE>" << endl;
}

// Parses a recorded POST body and checks the digest of its "file" part.
//
//   pool           : APR pool used by the parser
//   post_file_path : file containing the raw multipart body
//   content_type   : Content-Type header value that goes with the body
//   file_digest    : expected digest of the uploaded file
//   file_dir_path  : directory in which the parser stores uploaded files
//   dump_level     : values above 2 dump every parsed part
//
// Throws MESSAGE_BUG_FOUND when the body has no "file" part or the digest
// does not match.
void run_parse(apr_pool_t *pool, const char *post_file_path,
               const char *content_type, const char *file_digest,
               const char *file_dir_path, apr_size_t dump_level)
{
    apr_array_header_t *content_array;
    RequestReader::progress_t progress;
    ifstream stream(post_file_path, ios_base::binary);
    RFC1867ParserImpl::content_t *file_content;

    show_test_name("parse");

    File post_file(pool, post_file_path);
    CGIRequestReader reader(&progress, NULL, &stream);

    RFC1867ParserImpl parser(pool, reader, file_dir_path, MAX_TEXT_SIZE,
                             MAX_FILE_SIZE, MAX_ITEM_NUM);

    content_array = parser.parse(content_type,
                                 static_cast<apr_size_t>(post_file.get_size()));

    if (dump_level > 2) {
        RFC1867ParserImpl::dump_content_array(content_array);
    }

    file_content = RFC1867ParserImpl::get_content(content_array,
                                                  FILE_CONTENT_NAME);
    // get_content() returns NULL when the body has no such part; report
    // that as a test failure instead of dereferencing it.
    if (file_content == NULL) {
        THROW(MESSAGE_BUG_FOUND);
    }
    if (strncmp(file_content->file.digest, file_digest,
                strlen(file_digest)) != 0) {
        THROW(MESSAGE_BUG_FOUND);
    }

    show_spacer();
}

// Entry point of the test driver.
//
// Expected arguments: argv[1] POST body file, argv[2] Content-Type value,
// argv[3] expected file digest, argv[4] upload directory, and optionally
// argv[5] dump level (defaults to 0).
// Throws MESSAGE_ARGUMENT_INVALID when fewer than four arguments are given
// and MESSAGE_FILE_NOT_FOUND when the POST body file does not exist.
void run_all(apr_pool_t *pool, int argc, const char * const *argv)
{
    if (argc < 5) {
        THROW(MESSAGE_ARGUMENT_INVALID);
    }

    const char *post_file_path = argv[1];
    const char *content_type = argv[2];
    const char *file_digest = argv[3];
    const char *file_dir_path = argv[4];
    const apr_size_t dump_level = (argc >= 6) ? atoi(argv[5]) : 0;

    if (!is_exist(pool, post_file_path)) {
        THROW(MESSAGE_FILE_NOT_FOUND);
    }

    show_item("post_file_path", post_file_path);
    show_item("content_type", content_type);
    show_item("file_dir_path", file_dir_path);

    show_line();

    run_parse(pool, post_file_path, content_type, file_digest, file_dir_path,
              dump_level);
}

#endif

// Local Variables:
// mode: c++
// coding: utf-8-dos
// End:
