#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright © 2009 marmuta <marmvta@gmail.com>
# Copyright © 2010 Francesco Fumanti <francesco.fumanti@gmx.net>
#
# This file is part of Onboard.
#
# Onboard is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# Onboard is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import sys,math
import time
import codecs
import re
import gc
from contextlib import contextmanager
import lm

@contextmanager
def timeit(s):
    """Context manager that prints the wall-clock time spent in its body.

    s -- label appended to the output as " for <s>"; pass an empty string
         (or None) to print the time without a label.

    The time is printed only when the body finishes without raising.
    """
    # Run the collector several times before starting the clock; presumably
    # so that garbage left over from earlier work is not freed (and timed)
    # inside the measured body.
    for _ in range(3):
        gc.collect()

    started = time.time()
    yield None
    elapsed_ms = (time.time() - started) * 1000

    suffix = " for " + s if s else ""
    # Single-argument, parenthesised form: identical output with the
    # Python 2 print statement and the Python 3 print function.
    print("time: %fms%s" % (elapsed_ms, suffix))


def main():
    """Train a model on a corpus, then time prediction, saving and loading.

    Command line arguments:
        sys.argv[1]  -- path of the corpus file to read
        sys.argv[2]  -- order of the model (integer)
        sys.argv[3:] -- optional context words; the last one is the prefix
                        to complete. Defaults to a single empty word.

    The model is saved to and reloaded from /tmp/en.lm as part of the run.
    """
    global model  # kept at module level so it can be inspected when debugging

    if len(sys.argv) < 3:
        sys.stderr.write("usage: %s CORPUS_FILE ORDER [CONTEXT_WORD ...]\n"
                         % sys.argv[0])
        sys.exit(2)

    # Python 2 hands over arguments as byte strings. Decode them explicitly,
    # the implicit ASCII conversion fails on any non-ASCII context word.
    # On Python 3 the arguments already are text and pass through unchanged.
    text_type = type(u"")
    encoding = sys.getfilesystemencoding() or "utf-8"
    context = [w if isinstance(w, text_type) else w.decode(encoding)
               for w in (sys.argv[3:] or [u""])]

    model = LanguageModel(order=int(sys.argv[2]))
    training, held_out, testing = model.read_corpus(sys.argv[1])

    # incrementally add new n-grams
    model.train(training)
    model.train(held_out)   # abuse held_out for now
    model.train(testing)    # abuse testing for now

    model.info()

    # The sentence lists are no longer needed; drop them before timing.
    del training, held_out, testing

    with timeit("predict (50)"):
        choices = model.predict(context, 50)

    with timeit("predict (all)"):
        choices = model.predict(context)
    model.print_choices(context, choices)

    # Round trip through the file to check that a reloaded model reports
    # the same statistics and predictions.
    with timeit("save"):
        model.save("/tmp/en.lm")
    with timeit("load"):
        model.load("/tmp/en.lm")

    model.info()

    with timeit("predict"):
        choices = model.predict(context)
    model.print_choices(context, choices)

class LanguageModel(lm.LanguageModelDynamic):
    """Dynamic n-gram model with corpus reading, training and test helpers.

    Counting, probabilities, prediction and persistence come from
    lm.LanguageModelDynamic; this class adds the text handling around it.
    All output uses the single-argument print(...) form so the code runs
    unchanged with the Python 2 print statement and the Python 3 function.
    """

    # Lazily matches up to and including a sentence separator.
    _SENTENCE_RE = re.compile(r"""
        .*?(?:[.;:!?][\s"]    # punctuation followed by whitespace or a quote
              | \s*\n\s*\n)   # or a blank line (double newline)
        """, re.UNICODE | re.DOTALL | re.VERBOSE)

    # A word starts with a letter or underscore, continues with word
    # characters and may contain inner hyphens or apostrophes.
    _WORD_RE = re.compile(r"[^\W\d]\w*(?:[-'][\w]+)*",
                          re.UNICODE | re.VERBOSE)

    def __init__(self, filename = None, order = 3):
        """Remember filename and set the order (maximum n-gram length)."""
        self.filename = filename
        self.set_order(order)

    def read_corpus(self, filename, encoding = 'latin-1'):
        """Read a text file and split it into three lists of sentences.

        filename -- path of the corpus file
        encoding -- text encoding of the file, 'latin-1' by default

        Returns (training, held_out, testing). Out of every 20 sentences
        the one at index 5 goes to held_out, the one at index 15 to testing
        and the remaining 18 to training. Each list keeps corpus order.
        """
        # Close the file deterministically instead of leaking the handle.
        with codecs.open(filename, encoding=encoding) as corpus:
            text = corpus.read()
        text = text.replace("\r", " ")  # remove carriage returns from Moby Dick

        # split into sentences including separators (punctuation, newlines)
        sentences = self._SENTENCE_RE.findall(text)

        # divide corpus into 3 sections: training, held_out, test
        held_out = sentences[5::20]
        testing  = sentences[15::20]
        training = [sentence for i, sentence in enumerate(sentences)
                    if i % 20 not in (5, 15)]

        print("sentences: total %d, training %d, held_out %d, testing %d" %
              (len(sentences), len(training), len(held_out), len(testing)))

        return training, held_out, testing

    def info(self):
        """Print, per n-gram length, the number of n-grams and their total count."""
        counts = [0] * self.order
        totals = [0] * self.order
        # Each item from iter_ngrams() is indexed as (ngram, count, ...).
        for ng in self.iter_ngrams():
            level = len(ng[0]) - 1
            counts[level] += 1
            totals[level] += ng[1]

        for level in range(self.order):
            print("%d-grams: count %d, total %d" %
                  (level + 1, counts[level], totals[level]))

    def test(self, sentences):
        """Print diagnostics for the trained model.

        sentences -- sentences to compute entropy and perplexity for

        Prints the number of n-grams seen exactly once and twice per level,
        entropy/perplexity for sentences and two fixed sample sentences,
        and the probability sum over the vocabulary for each history length.
        """
        print("")
        print("testing...")

        # check total number of n-grams with exactly one and two counts
        n1s = [0] * self.order
        n2s = [0] * self.order
        for ng in self.iter_ngrams():
            level = len(ng[0]) - 1
            count = ng[1]
            if count == 1:
                n1s[level] += 1
            elif count == 2:
                n2s[level] += 1

        for level in range(self.order):
            print("%d: n1 %10d, n2 %10d" % (level + 1, n1s[level], n2s[level]))

        # entropy and perplexity
        self.test_sentences(sentences)
        self.test_sentences([u"I spent three years before the mast"])
        self.test_sentences([u"John read a book"])

        # test normalization, all numbers should be close to 1.0
        for level in range(self.order):
            history = [u"<s>", u"a", u"a"][:level]
            ps = 0
            for w in self.get_vocabulary():
                ps += self.get_probability(history + [w])
            print("order %d: %10f %s" % (level + 1, ps, history))

    def train(self, sentences):
        """Find and count all n-grams of length 1..order in sentences.

        sentences -- iterable of sentence strings; sentences without any
                     word contribute nothing.
        """
        for sentence in sentences:
            words = self.split_sentence(sentence)

            # extract and count n-grams starting at every word position,
            # clipped at the end of the sentence
            for start in range(len(words)):
                longest = min(start + self.order, len(words))
                for stop in range(start + 1, longest + 1):
                    self.count_ngram(words[start:stop])

    def test_sentences(self, sentences):
        """Print entropy (bit/word) and perplexity of the model for sentences.

        For every position the longest available n-gram ending there (at
        most order words) is scored. The lone sentence begin marker is
        skipped. A zero probability makes entropy and perplexity infinite.
        """
        word_count = 0
        ngram_count = 0
        log_prob_sum = 0
        for sentence in sentences:
            parts = self.split_sentence(sentence)
            # excluding sentence begin/end markers
            word_count += max(len(parts) - 2, 0)

            # extract n-grams of maximum length
            for i in range(len(parts)):
                ngram = parts[max(i - (self.order - 1), 0):i + 1]
                if len(ngram) == 1 and ngram[0] == u"<s>":
                    continue  # excluding unigram <s>
                p = self.get_probability(ngram)
                # log2(0) is minus infinity: an impossible n-gram has to
                # drive the entropy up, not down.
                log_prob_sum += math.log(p, 2) if p else float("-inf")
                ngram_count += 1

        entropy = -log_prob_sum / word_count if word_count else 0
        try:
            perplexity = 2 ** entropy
        except OverflowError:
            # too large for a float; report it as unbounded
            perplexity = float("inf")
        print("test: words %d, n-grams %d, entropy %f bit/word, perplexity %f" %
              (word_count, ngram_count, entropy, perplexity))

    def split_sentence(self, sentence):
        """Split sentence into words framed by begin and end markers.

        Returns [] when the sentence contains no word, otherwise
        [<s>, word, ..., <\\s>]. Numbers and punctuation are dropped.
        """
        # split into sentence parts (words only)
        parts = self._WORD_RE.findall(sentence)
        if parts:
            # NOTE(review): the end marker is spelled with a backslash
            # ("<\s>"), possibly a typo for "</s>". Value kept unchanged
            # because trained or saved models may contain it - confirm.
            parts = [u"<s>"] + parts + [u"<\\s>"]
        return parts

    def split_context(self, context):
        """Split context into (history, prefix).

        prefix is the last word of context, history the up to order-1
        words before it. An empty context yields ([], u"").
        """
        if not context:
            return [], u""
        n = min(self.order, len(context))
        history = context[-n:-1]
        prefix  = context[-1]
        return history, prefix

    def print_choices(self, context, choices):
        """Print the probability sum of choices and details of the first 20.

        context -- list of words as passed to predict()
        choices -- sequence of items indexed as (word, probability)

        Each detail line shows the probability, the counts of the n-gram
        history+word and of its shorter suffixes, and the word itself.
        """
        history, prefix = self.split_context(context)
        print("")
        print("history: %s prefix '%s' " % (history, prefix))

        # Left-pad short n-grams so that every row has 'order' columns.
        padding = [u""] * max(self.order - len(context), 0)
        psum = 0
        counts = []
        for x in choices:
            psum += x[1]
            ng = padding + history + [x[0]]
            counts.append([self.get_ngram_count(ng[i:])
                           for i in range(self.order)])

        # ought to be 1.0 for the whole vocabulary
        print("Probability sum %f for %d results" % (psum, len(choices)))
        for i, x in enumerate(choices[:20]):
            columns = "".join("%8d " % c for c in counts[i])
            line = "%10f " % x[1] + columns + "'%s'" % x[0]
            print(line)


if __name__ == '__main__':
    main()

