# htmltree.py by hylom
# -*- coding: utf-8 -*-

"""htmltree.py - HTML Element-Tree Builder
by hylom <hylomm@@single_at_mark@@gmail.com>
"""

import HTMLParser
import re

class HTMLElementError(Exception):
    """Error concerning a particular HTMLElement.

    Attributes:
    msg -- description of the problem
    elem -- the element (or other object) the problem relates to
    """
    def __init__(self, msg, elem):
        Exception.__init__(self, msg, elem)
        self.msg = msg
        self.elem = elem

    def __repr__(self):
        return "HTML Element Error: %s in %s" % (self.msg, self.elem)

    # str(e), which tracebacks and logging use, shows the same readable
    # message instead of the raw args tuple.
    __str__ = __repr__

class Renderer(object):
    """HTMLElement render base class.

    Concrete renderers walk an element tree; this base class provides the
    attribute formatter they share.
    """
    def attrs2str(self, elem):
        """returns elem.attrs formatted as an HTML attribute string.

        The result is "" when elem has no attributes. Otherwise it starts with
        a space (e.g. " href='x' checked") so it can be appended directly to
        a tag name.

        Formatting rules:
        - An attribute whose value is None is emitted bare.
        - Other values are single-quoted, unless they contain a single quote.
        - Those are double-quoted, with any embedded double quote escaped as
          &quot; so the attribute stays well-formed.
        """
        strs = []
        for attr in elem.attrs:
            value = elem.attrs[attr]
            if value is None:
                strs.append(attr)
            elif "'" in value:
                strs.append('%s="%s"' % (attr, value.replace('"', "&quot;")))
            else:
                strs.append("%s='%s'" % (attr, value))
        if not strs:
            return ""
        return " " + " ".join(strs)

class HTMLRenderer(Renderer):
    """Render HTMLElement as HTML."""
    # HTML void elements: they never take an end tag, so none is emitted.
    # This includes "input", which HTMLTree already parses as a start-end tag.
    UNCLOSABLE_TAGS = ["area", "base", "br", "col", "embed", "hr", "img",
                       "input", "keygen", "link", "meta", "param", "source",
                       "track", "wbr"]

    def render_inner(self, elem):
        """returns HTML for the children of elem (elem's own tag excluded)."""
        texts = []
        for child in elem:
            self._recursive(child, texts)
        return "".join(texts)

    def render(self, elem):
        """returns HTML for elem itself, including its whole subtree."""
        texts = []
        self._recursive(elem, texts)
        return "".join(texts)

    def _recursive(self, elem, texts):
        # Append the HTML for elem and its subtree to the texts list.
        if elem.is_tag():
            texts.append("<" + elem.name + self.attrs2str(elem) + ">")
            for child in elem:
                self._recursive(child, texts)
            if elem.name not in self.UNCLOSABLE_TAGS:
                texts.append("</" + elem.name + ">")
        elif elem.is_text():
            # empty text nodes contribute nothing
            if elem.text():
                texts.append(elem.text())
        elif elem.is_root():
            for child in elem:
                self._recursive(child, texts)
        elif elem.is_decl():
            texts.append("<!" + elem.name + ">")
        elif elem.is_comment():
            texts.append("<!--" + elem.name + "-->")


class TEXTRenderer(Renderer):
    """Render an HTMLElement tree as plain text.

    Only TEXT nodes produce output. TAG and ROOT nodes are descended into
    but emit nothing themselves. Declarations and comments are skipped.
    """
    # TODO: check tags not need to close more strict...
    # (This renderer never consults the list; it is kept as a class attribute.)
    UNCLOSABLE_TAGS = ["br", "link", "meta", "img"]

    def render(self, elem):
        """returns the text of elem itself, including its subtree."""
        out = []
        self._recursive(elem, out)
        return "".join(out)

    def render_inner(self, elem):
        """returns the concatenated text of elem's children."""
        out = []
        for node in elem:
            self._recursive(node, out)
        return "".join(out)

    def _recursive(self, elem, texts):
        # Collect non-empty text from elem's subtree into texts, in order.
        if elem.is_text():
            content = elem.text()
            if content:
                texts.append(content)
            return
        if elem.is_tag() or elem.is_root():
            for node in elem:
                self._recursive(node, texts)

class HTMLElement(list):
    """HTML element object to use as tree nodes.

    An HTMLElement is a list whose items are its child elements. It also
    carries:
    type  -- node kind: ROOT, TAG, TEXT, DECL or COMMENT
    name  -- tag name for TAG nodes; the raw content for DECL/COMMENT
             nodes; "" by default
    attrs -- dict of attribute name -> value (None for valueless ones)

    Parent and sibling links are available through parent(), next() and
    prev().
    """
    ROOT = 0
    TAG = 100
    TEXT = 200
    DECL = 300
    COMMENT = 400

    def __init__(self, type, name="", attrs=None):
        """
        create HTMLElement object.

        Arguments:
        type -- element type. HTMLElement.(ROOT|TAG|TEXT|DECL|COMMENT)
        name -- element name (default: "")
        attrs -- dict (or sequence of (key, value) pairs) of attributes.
                 It is copied, never shared (default: no attributes).

        Example:
        attr = dict(href="http://example.com/", target="_blank")
        e = HTMLElement(HTMLElement.TAG, "a", attr)
        # 'e' means <a href="http://example.com/" target="_blank">
        """

        self.type = type
        self.name = name
        self.attrs = dict(attrs) if attrs is not None else {}
        self._text = ""
        self._parent = None
        self._next_elem = None
        self._prev_elem = None

    def __repr__(self):
        if self.type == HTMLElement.TAG:
            return "<TAG:%s %s>" % (self.name, self._attrs2str())
        elif self.type == HTMLElement.DECL:
            return "<DECL:'%s'>" % self.name
        elif self.type == HTMLElement.COMMENT:
            return "<COMMENT:'%s'>" % self.name
        elif self.type == HTMLElement.TEXT:
            return "<TEXT:'%s'>" % self._text
        elif self.type == HTMLElement.ROOT:
            return "<ROOT>"
        else:
            return "<UNKNOWN>"

    def __eq__(self, other):
        # Elements compare by identity, not content, so list operations such
        # as remove() act on the exact node.
        return self is other

    def __ne__(self, other):
        # Without this, "!=" would use list.__ne__ (content comparison) and
        # disagree with __eq__.
        return self is not other

    # Mutable list subclass: stays unhashable.
    __hash__ = None

    def _attrs2str(self):
        # Debug formatting used by __repr__: "key='value'" or a bare "key".
        return " ".join(
            key if value is None else "%s='%s'" % (key, value)
            for key, value in self.attrs.items())

    # basic acquisition functions
    def get_attribute(self, attr, default=None):
        """returns given attribute's value (default if absent)."""
        return self.attrs.get(attr, default)

    def attr(self, attr, default=None):
        """returns given attribute's value (alias of get_attribute)."""
        return self.get_attribute(attr, default)

    def has_attribute(self, attr):
        """returns True if element has "attr" attribute."""
        return attr in self.attrs

    def text(self):
        """returns the text content of this (TEXT) node."""
        return self._text

    def inner_html(self):
        "returns inner html"
        rn = HTMLRenderer()
        return rn.render_inner(self)

    def inner_text(self):
        "returns inner text"
        rn = TEXTRenderer()
        return rn.render_inner(self)

    def get_classes(self):
        "returns the whitespace-separated values of the class attribute"
        attr = self.get_attribute('class')
        if attr is None:
            return []
        return attr.split()

    # navigation functions
    def parent(self):
        """returns tag's parent element."""
        return self._parent

    def next(self):
        """returns tag's next sibling element."""
        return self._next_elem

    def prev(self):
        """returns tag's previous sibling element."""
        return self._prev_elem

    def next_tag(self):
        """returns the next sibling that is a tag, or None."""
        node = self.next()
        while node is not None and not node.is_tag():
            node = node.next()
        return node

    def prev_tag(self):
        """returns the previous sibling that is a tag, or None."""
        node = self.prev()
        while node is not None and not node.is_tag():
            node = node.prev()
        return node

    # basic query functions
    def get_elements_by_name(self, name):
        """returns all descendants (not self) whose name is "name"."""
        buf = []
        for i in self:
            i._r_get_elements_by_name(name, buf)
        return buf

    def _r_get_elements_by_name(self, name, buf):
        if self.name == name:
            buf.append(self)
        for i in self:
            i._r_get_elements_by_name(name, buf)

    def get_comments(self):
        """returns all COMMENT descendants."""
        buf = []
        for i in self:
            i._r_get_comments(buf)
        return buf

    def _r_get_comments(self, buf):
        if self.is_comment():
            buf.append(self)
        for i in self:
            i._r_get_comments(buf)

    def get_element_by_id(self, id):
        """returns the first descendant whose id attribute is "id", or None."""
        for i in self:
            if "id" in i.attrs and i.attrs["id"] == id:
                return i
            e = i.get_element_by_id(id)
            if e is not None:
                return e
        return None

    def get_elements_by_class(self, cls):
        """returns all descendants having "cls" among their classes."""
        buf = []
        for i in self:
            i._r_get_elements_by_class(cls, buf)
        return buf

    def _r_get_elements_by_class(self, cls, buf):
        if cls in self.get_classes():
            buf.append(self)
        for i in self:
            i._r_get_elements_by_class(cls, buf)

    def get_elements(self, name, attrs):
        """returns descendants named "name" whose attributes match "attrs".

        Each key of attrs must map to the same value on the element. A
        missing attribute is compared as "".
        """
        results = []
        for elem in self.get_elements_by_name(name):
            if all(elem.get_attribute(key, "") == value
                   for key, value in attrs.items()):
                results.append(elem)
        return results

    # manipulation functions
    def _link_child(self, elem):
        # Attach elem as the last child, keeping parent/sibling links valid.
        if len(self) > 0:
            last = self[-1]
            last._next_elem = elem
            elem._prev_elem = last
        elem._parent = self
        self.append(elem)

    def append_tag(self, tag, attrs):
        """append a new TAG child named "tag" and return it."""
        elem = HTMLElement(HTMLElement.TAG, tag, attrs)
        self._link_child(elem)
        return elem

    def remove_element(self, elem):
        """detach "elem" from its parent (see delete())."""
        elem.delete()

    def delete(self):
        """detach this element from its parent and fix sibling links.

        Raises AttributeError if the element has no parent.
        """
        parent = self.parent()
        parent.remove(self)
        prev_elem, next_elem = self._prev_elem, self._next_elem
        if prev_elem is not None:
            prev_elem._next_elem = next_elem
        if next_elem is not None:
            next_elem._prev_elem = prev_elem
        self._parent = None
        self._prev_elem = None
        self._next_elem = None

    # query functions
    # TODO: only "#id", ".class", tag names and the descendant combinator
    # (whitespace) are supported.
    def select(self, expr):
        """returns descendants matching a simple selector expression.

        expr is a whitespace-separated list of terms. Each term is "#id",
        ".class" or a tag name. Each term is searched among the descendants
        of the previous term's matches; the first term searches the
        descendants of this element. Duplicates are removed, keeping the
        order in which elements were found.
        """
        terms = expr.split()
        if not terms:
            return []
        results = [self]
        for pat in terms:
            found = []
            seen = set()
            for elem in results:
                for match in self._select_pattern(pat, elem):
                    if id(match) not in seen:
                        seen.add(id(match))
                        found.append(match)
            results = found
        return results

    def _select_pattern(self, pat, elem):
        # Descendants of elem matching one selector term.
        if pat[0] == "#":
            results = [elem.get_element_by_id(pat[1:])]
        elif pat[0] == ".":
            results = elem.get_elements_by_class(pat[1:])
        else:
            results = elem.get_elements_by_name(pat)
        return [x for x in results if x is not None]

    def select_1st(self, expr):
        """returns the first element select(expr) finds, or None."""
        r = self.select(expr)
        return r[0] if r else None

    def select_by_name2(self, term1, term2):
        """returns elements named term2 found under elements named term1."""
        buf = []
        for elem in self.get_elements_by_name(term1):
            buf.extend(elem.get_elements_by_name(term2))
        return buf

    # is_* functions
    def is_text(self):
        return self.type == HTMLElement.TEXT

    def is_tag(self):
        return self.type == HTMLElement.TAG

    def is_root(self):
        return self.type == HTMLElement.ROOT

    def is_decl(self):
        return self.type == HTMLElement.DECL

    def is_comment(self):
        return self.type == HTMLElement.COMMENT

    def is_descendant(self, tagname):
        """returns the nearest ancestor named "tagname", or False.

        An ancestor always has at least one child, so a found element is
        truthy and the result can be used as a boolean.
        """
        p = self.parent()
        while p is not None:
            if p.name == tagname:
                return p
            p = p.parent()
        return False

    def trace_back(self, tag):
        """returns the names of ancestors matching regexp "tag", nearest first."""
        rex = re.compile(tag)
        result = []
        p = self.parent()
        while p is not None:
            if rex.search(p.name):
                result.append(p.name)
            p = p.parent()
        return result


class HTMLTreeError(Exception):
    """Raised by HTMLTree.parse when the underlying HTMLParser fails.

    Attributes:
    msg -- error message
    lineno -- line number of the error (may be None when unknown)
    offset -- character offset of the error (may be None when unknown)
    """
    def __init__(self, msg, lineno, offset):
        Exception.__init__(self, msg, lineno, offset)
        self.msg = msg
        self.lineno = lineno
        self.offset = offset

    def __repr__(self):
        # %s rather than %d: HTMLParseError's position defaults to
        # (None, None), and "%d" % None would raise TypeError here.
        return "HTML Parse Error: %s , line: %s, char: %s" % (
            self.msg, self.lineno, self.offset)

    # str(e), which tracebacks and logging use, shows the same message.
    __str__ = __repr__


def parse(data, charset=None, option=0):
    """Build and return an HTMLTree for the given HTML.

    The arguments are passed unchanged to HTMLTree.parse():
    data -- HTML to parse
    charset -- charset of HTML (default: None)
    option -- option flags (default: 0, meaning none)
    """
    builder = HTMLTree()
    builder.parse(data, charset=charset, option=option)
    return builder


class HTMLTree(HTMLParser.HTMLParser):
    """HTML Tree Builder.

    Feeds HTML through HTMLParser and builds an HTMLElement tree from the
    parser callbacks. root() returns the resulting ROOT element.
    """
    USE_VALIDATE = 0x0001

    IGNORE_BLANK = 0x0010
    # NOTE: TRUNC_BLANK, JOIN_TEXT and TRUNC_BR are not consulted by the
    # handlers in this class.
    TRUNC_BLANK  = 0x0020
    JOIN_TEXT    = 0x0040

    TRUNC_BR = 0x0100
    # HTML void elements: treated as start-end tags (they never get
    # children) and their end tags are ignored.
    UNCLOSABLE_TAGS = ["area", "base", "br", "col", "embed", "hr", "img",
                       "input", "keygen", "link", "meta", "param", "source",
                       "track", "wbr"]

    def __init__(self):
        "Constructor"
        HTMLParser.HTMLParser.__init__(self)
        # Defaults so root() and the handlers work before parse() is called.
        self.charset = None
        self._option = 0
        self._htmlroot = HTMLElement(HTMLElement.ROOT)
        self._cursor = self._htmlroot

    def parse(self, data, charset=None, option=0):
        """
        Parse given HTML.

        Arguments:
        data -- HTML to parse
        charset -- charset of HTML (default: None). When None, a charset
                   declared by a <meta> tag is looked for after a first
                   pass; if a usable one is found, data is parsed again
                   with it.
        option -- bitwise OR of option flags (default: 0, meaning none)

        Raises:
        HTMLTreeError -- if HTMLParser reports a parse error.
        """
        self.charset = charset
        self._option = option
        self._build(data)

        # if charset is not given, detect it and re-parse with it
        if self.charset is None:
            self.charset = self._detect_charset()
            if self.charset:
                self._build(data)

        self._finalize()

    def _build(self, data):
        # Parse data from scratch into a fresh ROOT element.
        self._htmlroot = HTMLElement(HTMLElement.ROOT)
        self._cursor = self._htmlroot
        # Drop parser state (buffered input, position) left by an earlier pass.
        self.reset()
        try:
            self.feed(data)
            # Flush any input HTMLParser is still buffering.
            self.close()
        except HTMLParser.HTMLParseError as e:
            raise HTMLTreeError("HTML parse error: " + e.msg,
                                e.lineno, e.offset)

    def _detect_charset(self):
        # Return the first usable charset declared by a <meta> tag, or None.
        # Recognises <meta charset="X"> and
        # <meta http-equiv="Content-Type" content="...; charset=X">.
        # Names unknown to the codecs registry are skipped.
        import codecs
        for meta in self.root().get_elements_by_name("meta"):
            candidate = meta.attrs.get("charset")
            if not candidate:
                equiv = meta.attrs.get("http-equiv") or ""
                if equiv.lower() != "content-type":
                    continue
                content = meta.attrs.get("content") or ""
                m = re.search(r"charset\s*=\s*([^;]+)", content, re.I)
                if m is None:
                    continue
                candidate = m.group(1)
            candidate = candidate.strip().strip("'\"")
            if not candidate:
                continue
            try:
                codecs.lookup(candidate)
            except LookupError:
                continue
            return candidate
        return None

    def _finalize(self):
        # Fill in the prev/next sibling links of every node in the tree.
        self._r_finalize(self.root())

    def _r_finalize(self, elem):
        if elem.is_text():
            return
        for before, after in zip(elem, elem[1:]):
            before._next_elem = after
            after._prev_elem = before
        for sub_elem in elem:
            self._r_finalize(sub_elem)

    def validate(self):
        """Tree validation is not implemented yet."""
        raise NotImplementedError("HTMLTree.validate() is not implemented")

    # tools
    def _text_encoder(self, text):
        # Decode byte strings to unicode when a charset is known.
        # unicode() raises TypeError for text that is already unicode (or
        # None); such values are returned unchanged.
        if not self.charset:
            return text
        try:
            return unicode(text, self.charset)
        except TypeError:
            return text

    def _attr_encoder(self, attrs):
        return [(k, self._text_encoder(v)) for (k, v) in attrs]

    def _append(self, elem):
        # Attach elem as the last child of the current cursor node.
        elem._parent = self._cursor
        self._cursor.append(elem)

    # Handlers
    def handle_starttag(self, tag, attrs):
        # some tags treat as start-end tag.
        if tag in self.UNCLOSABLE_TAGS:
            return self.handle_startendtag(tag, attrs)

        if self._option & HTMLTree.USE_VALIDATE:
            # try validation (experimental): an <li> opened directly inside
            # an unclosed <li> closes the previous one.
            if tag == "li" and self._cursor.name == "li":
                self.handle_endtag("li")

        elem = HTMLElement(HTMLElement.TAG, tag, self._attr_encoder(attrs))
        self._append(elem)
        self._cursor = elem

    def handle_endtag(self, tag):
        # some tags treat as start-end tag; their end tags are ignored.
        if tag in self.UNCLOSABLE_TAGS:
            return

        # Close the nearest open element with this name, implicitly closing
        # any unclosed elements nested inside it. A stray end tag with no
        # matching open element is ignored, so the cursor never moves above
        # the root.
        node = self._cursor
        while node is not None and not node.is_root():
            if node.name == tag:
                self._cursor = node.parent()
                return
            node = node.parent()

    def handle_startendtag(self, tag, attrs):
        self._append(HTMLElement(HTMLElement.TAG, tag,
                                 self._attr_encoder(attrs)))

    def handle_data(self, data):
        # With IGNORE_BLANK, whitespace-only text becomes an empty TEXT node.
        if self._option & HTMLTree.IGNORE_BLANK:
            if re.search(r"^\s*$", data):
                data = ""

        elem = HTMLElement(HTMLElement.TEXT)
        # decode text to unicode when a charset is known
        elem._text = self._text_encoder(data)
        self._append(elem)

    def handle_entityref(self, name):
        # keep entity references verbatim as text
        self.handle_data("&" + name + ";")

    def handle_charref(self, ref):
        # keep character references verbatim as text
        self.handle_data("&#" + ref + ";")

    def handle_decl(self, decl):
        self._append(HTMLElement(HTMLElement.DECL, decl))

    def handle_comment(self, data):
        self._append(HTMLElement(HTMLElement.COMMENT, data))

    # Accessor
    def root(self):
        """returns the ROOT element of the built tree."""
        return self._htmlroot
