# Copyright 2012 Matt Chaput. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
#    1. Redistributions of source code must retain the above copyright notice,
#       this list of conditions and the following disclaimer.
#
#    2. Redistributions in binary form must reproduce the above copyright
#       notice, this list of conditions and the following disclaimer in the
#       documentation and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY MATT CHAPUT ``AS IS'' AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
# EVENT SHALL MATT CHAPUT OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# The views and conclusions contained in the software and documentation are
# those of the authors and should not be interpreted as representing official
# policies, either expressed or implied, of Matt Chaput.

from ast import literal_eval

from whoosh.compat import b, bytes_type, text_type, integer_types, PY3
from whoosh.compat import iteritems, dumps, loads, xrange
from whoosh.codec import base
from whoosh.matching import ListMatcher
from whoosh.reading import TermInfo, TermNotFound

# When PY3 is false, define a placeholder ``memoryview`` class so that the
# isinstance() check in LineWriter._print_line can reference the name.
# NOTE(review): this also shadows any builtin of the same name on that
# interpreter -- confirm that is intended.
if not PY3:
    class memoryview:
        pass

# Value types that LineWriter._print_line accepts and writes using repr().
_reprable = (bytes_type, text_type, integer_types, float)


# Mixin classes for producing and consuming the simple text format

class LineWriter(object):
    """Mixin that writes commands as lines of the simple text format.

    The host class must provide ``self._dbfile``, a binary file-like object
    with a ``write`` method.
    """

    def _print_line(self, indent, command, **kwargs):
        """Write a single ``command`` line at the given nesting level.

        The line is made of two spaces per ``indent`` level, the command
        name, then one tab-prefixed ``key=repr(value)`` pair for each keyword
        argument, and a terminating newline.

        :param indent: nesting depth of the line.
        :param command: the command name, as text.
        :param kwargs: values to attach to the line. Each must be ``None``,
            a ``memoryview`` (converted with ``bytes()``), or an instance of
            one of the ``_reprable`` types.
        :raises TypeError: if a value has an unsupported type.
        """

        dbfile = self._dbfile
        dbfile.write(b("  ") * indent)
        dbfile.write(command.encode("latin1"))
        for k, v in iteritems(kwargs):
            if isinstance(v, memoryview):
                v = bytes(v)
            if v is not None and not isinstance(v, _reprable):
                raise TypeError(type(v))
            # repr() on Python 3 leaves printable non-ASCII characters
            # unescaped, so a character outside latin1 would make a strict
            # encode() raise UnicodeEncodeError. "backslashreplace" writes
            # such characters as \uXXXX / \UXXXXXXXX escapes, which are valid
            # inside the quoted literal and evaluate back to the same text.
            pair = "\t%s=%r" % (k, v)
            dbfile.write(pair.encode("latin1", "backslashreplace"))
        dbfile.write(b("\n"))


class LineReader(object):
    """Mixin that scans a file written in the simple text format.

    Each line is ``<indent><command>`` followed by tab-separated
    ``key=<python literal>`` pairs, where the indent is two spaces per
    nesting level. The reader keeps no index: every search reads forward
    from the current file position.

    The host object must provide ``self._dbfile``, a binary file-like object
    supporting ``readline()``, ``tell()`` and ``seek()``.
    """

    def __init__(self, dbfile):
        self._dbfile = dbfile

    def _reset(self):
        """Rewind the underlying file to its beginning."""

        self._dbfile.seek(0)

    def _find_line(self, indent, command, **kwargs):
        """Return the argument dict of the next matching line.

        Takes the same arguments as :meth:`_find_lines` and returns its first
        result, or ``None`` if the enclosing section ends (or the file ends)
        before a match is found.
        """

        for largs in self._find_lines(indent, command, **kwargs):
            return largs
        return None

    def _find_lines(self, indent, command, **kwargs):
        """Yield the argument dicts of matching lines in the current section.

        Reads forward from the current position. A line matches when it has
        exactly the given ``indent`` and ``command`` and, for every keyword
        argument given, carries an equal value under that key. Lines that do
        not match but are at the same or a deeper indent are skipped, as are
        blank lines and ``#`` comments.

        The scan stops at the end of the file or at the first line whose
        indent is shallower than ``indent``. That line belongs to an
        enclosing section, so the file is repositioned to its start and a
        later search at the outer level can still see it.
        """

        dbfile = self._dbfile
        while True:
            # Remember where this line starts so it can be pushed back.
            pos = dbfile.tell()
            line = dbfile.readline()
            if not line:
                return

            c = self._parse_line(line)
            if c is None:
                # Blank line or comment: no command here, keep scanning.
                continue

            lindent, lcommand, largs = c
            if lindent < indent:
                # Start of the next outer-level line: un-read it so it is not
                # lost to the caller's next search.
                dbfile.seek(pos)
                return

            if lindent == indent and lcommand == command:
                if not any(kwargs[k] != largs.get(k) for k in kwargs):
                    yield largs

    def _parse_line(self, line):
        """Parse one raw line into ``(indent, command, args)``.

        :param line: the line as bytes; it is decoded as latin1.
        :returns: ``None`` for a blank line or a line starting with ``#``;
            otherwise a tuple of the nesting level (leading whitespace
            length divided by two), the command name, and a dict of the
            ``key=value`` arguments with each value evaluated by
            ``literal_eval``.
        """

        line = line.decode("latin1")
        line = line.rstrip()
        l = len(line)
        line = line.lstrip()
        if not line or line.startswith("#"):
            return None

        indent = (l - len(line)) // 2

        parts = line.split("\t")
        command = parts[0]
        args = {}
        for part in parts[1:]:
            # Split on the first "=" only: the repr of a value may itself
            # contain "=" characters.
            n, v = part.split("=", 1)
            args[n] = literal_eval(v)
        return (indent, command, args)

    def _find_root(self, command):
        """Rewind the file and position it just after a top-level line.

        :param command: the command name of the indent-0 line to find.
        :raises Exception: if no such line exists in the file.
        """

        self._reset()
        c = self._find_line(0, command)
        if c is None:
            raise Exception("No root section %r" % (command,))


# Codec class

class PlainTextCodec(base.Codec):
    """Codec whose writers and readers use the simple text line format.

    Each factory method below simply instantiates the matching ``Plain*``
    class defined in this module.
    """

    # NOTE(review): presumably tells the framework that this codec does not
    # keep separate field-length statistics -- confirm against base.Codec.
    length_stats = False

    def per_document_writer(self, storage, segment):
        """Return a PlainPerDocWriter for the given storage and segment."""
        return PlainPerDocWriter(storage, segment)

    def field_writer(self, storage, segment):
        """Return a PlainFieldWriter for the given storage and segment."""
        return PlainFieldWriter(storage, segment)

    def per_document_reader(self, storage, segment):
        """Return a PlainPerDocReader for the given storage and segment."""
        return PlainPerDocReader(storage, segment)

    def terms_reader(self, storage, segment):
        """Return a PlainTermsReader for the given storage and segment."""
        return PlainTermsReader(storage, segment)

    def new_segment(self, storage, indexname):
        """Return a new PlainSegment for ``indexname``.

        The ``storage`` argument is not used.
        """
        return PlainSegment(indexname)


class PlainPerDocWriter(base.PerDocumentWriter, LineWriter):
    """Writes per-document data to a segment's ".dcs" file as text lines.

    The file starts with a ``DOCS`` root line. Each document adds a ``DOC``
    line at indent 1, with its fields, column values and vectors nested
    below it.
    """

    def __init__(self, storage, segment):
        # Create the output file and immediately write the root line.
        self._dbfile = storage.create_file(segment.make_filename(".dcs"))
        self._print_line(0, "DOCS")
        self.is_closed = False

    def start_doc(self, docnum):
        """Write the ``DOC`` line that opens document ``docnum``."""
        self._print_line(1, "DOC", dn=docnum)

    def add_field(self, fieldname, fieldobj, value, length):
        """Write a ``DOCFIELD`` line with the field's value and length.

        A value other than ``None`` is first serialized with ``dumps`` (called
        with 2 as its second argument). ``fieldobj`` is not used.
        """
        if value is not None:
            value = dumps(value, 2)
        self._print_line(2, "DOCFIELD", fn=fieldname, v=value, len=length)

    def add_column_value(self, fieldname, columnobj, value):
        """Write a ``COLVAL`` line for the field; ``columnobj`` is not used."""
        self._print_line(2, "COLVAL", fn=fieldname, v=value)

    def add_vector_items(self, fieldname, fieldobj, items):
        """Write a ``VECTOR`` line followed by one ``VPOST`` line per item.

        ``items`` is iterated as ``(text, weight, vbytes)`` tuples.
        ``fieldobj`` is not used.
        """
        self._print_line(2, "VECTOR", fn=fieldname)
        for text, weight, vbytes in items:
            self._print_line(3, "VPOST", t=text, w=weight, v=vbytes)

    def finish_doc(self):
        """Do nothing; the text format needs no end-of-document marker."""
        pass

    def close(self):
        """Close the output file and mark this writer as closed."""
        self._dbfile.close()
        self.is_closed = True


class PlainPerDocReader(base.PerDocumentReader, LineReader):
    """Reads per-document data from a segment's plain-text ".dcs" file.

    There is no index: every query rewinds the file and scans it forward
    with the LineReader helpers. The format has no notion of deleted
    documents, so the deletion-related methods report none.
    """

    def __init__(self, storage, segment):
        self._dbfile = storage.open_file(segment.make_filename(".dcs"))
        self._segment = segment
        self.is_closed = False

    def doc_count(self):
        """Return the document count reported by the segment."""
        return self._segment.doc_count()

    def doc_count_all(self):
        """Return the same value as :meth:`doc_count`."""
        return self._segment.doc_count()

    def has_deletions(self):
        """Always return ``False``."""
        return False

    def is_deleted(self, docnum):
        """Always return ``False``."""
        return False

    def deleted_docs(self):
        """Return an empty frozenset."""
        return frozenset()

    def _find_doc(self, docnum):
        """Position the file just after the ``DOC`` line for ``docnum``.

        Scans ``DOC`` lines from the start of the file and gives up at the
        first document number greater than ``docnum``, which assumes
        documents were written in ascending order.

        :returns: ``True`` if the document was found, otherwise ``False``.
        """

        self._find_root("DOCS")
        c = self._find_line(1, "DOC")
        while c is not None:
            dn = c["dn"]
            if dn == docnum:
                return True
            elif dn > docnum:
                return False
            c = self._find_line(1, "DOC")
        return False

    def _iter_docs(self):
        """Yield each document number in file order.

        After each yield the file is positioned just after that document's
        ``DOC`` line, so the caller can read the document's nested lines.
        """

        self._find_root("DOCS")
        c = self._find_line(1, "DOC")
        while c is not None:
            yield c["dn"]
            c = self._find_line(1, "DOC")

    def _iter_docfields(self, fieldname):
        """Yield the argument dict of each ``DOCFIELD`` line for a field."""

        for _ in self._iter_docs():
            for c in self._find_lines(2, "DOCFIELD", fn=fieldname):
                yield c

    def _iter_lengths(self, fieldname):
        """Return a generator of the recorded ``len`` values for a field.

        A ``DOCFIELD`` line without a ``len`` argument counts as 0.
        """

        return (c.get("len", 0) for c in self._iter_docfields(fieldname))

    def doc_field_length(self, docnum, fieldname, default=0):
        """Return the recorded length of a field in one document.

        :returns: the ``len`` value from the matching ``DOCFIELD`` line, or
            ``default`` if the document or the field is not found.
        """

        for dn in self._iter_docs():
            if dn == docnum:

                c = self._find_line(2, "DOCFIELD", fn=fieldname)
                if c is not None:
                    return c.get("len", default)
            elif dn > docnum:
                break

        return default

    def _column_values(self, fieldname):
        """Yield the column value of ``fieldname`` for every document.

        :raises Exception: if the document numbers do not run 0, 1, 2, ...
            or if a document has no ``COLVAL`` line for the field.
        """

        for i, docnum in enumerate(self._iter_docs()):
            if i != docnum:
                raise Exception("Missing column value for field %r doc %d?"
                                % (fieldname, i))

            c = self._find_line(2, "COLVAL", fn=fieldname)
            if c is None:
                raise Exception("Missing column value for field %r doc %d?"
                                % (fieldname, docnum))

            yield c.get("v")

    def has_column(self, fieldname):
        """Return ``True`` if a column value can be read for the field.

        Only the first value is pulled from :meth:`_column_values`, so its
        exceptions propagate and an empty file yields ``False``.
        """

        for _ in self._column_values(fieldname):
            return True
        return False

    def column_reader(self, fieldname, column):
        """Return all column values for the field as a list.

        The ``column`` argument is not used.
        """

        return list(self._column_values(fieldname))

    def field_length(self, fieldname):
        """Return the sum of the field's recorded lengths."""
        return sum(self._iter_lengths(fieldname))

    def min_field_length(self, fieldname):
        """Return the smallest recorded length of the field.

        :raises ValueError: (from ``min``) if no length is recorded.
        """
        return min(self._iter_lengths(fieldname))

    def max_field_length(self, fieldname):
        """Return the largest recorded length of the field.

        :raises ValueError: (from ``max``) if no length is recorded.
        """
        return max(self._iter_lengths(fieldname))

    def has_vector(self, docnum, fieldname):
        """Return ``True`` if the document has a vector for ``fieldname``."""

        if self._find_doc(docnum):
            # Filter on the field name; otherwise a vector stored for any
            # other field of the document would be reported.
            if self._find_line(2, "VECTOR", fn=fieldname) is not None:
                return True
        return False

    def vector(self, docnum, fieldname, format_):
        """Return a ListMatcher over the stored vector of a document field.

        :raises Exception: if the document does not exist or has no vector
            for ``fieldname``.
        """

        if not self._find_doc(docnum):
            raise Exception("No document %r" % (docnum,))
        # Match the VECTOR line of the requested field, not just the first
        # vector found in the document.
        if self._find_line(2, "VECTOR", fn=fieldname) is None:
            raise Exception("No vector for field %r in document %r"
                            % (fieldname, docnum))

        ids = []
        weights = []
        values = []
        c = self._find_line(3, "VPOST")
        while c is not None:
            ids.append(c["t"])
            weights.append(c["w"])
            values.append(c["v"])
            c = self._find_line(3, "VPOST")

        return ListMatcher(ids, weights, values, format_,)

    def _read_stored_fields(self):
        """Read the ``DOCFIELD`` lines of the current document into a dict.

        The file must already be positioned just after a ``DOC`` line. Values
        other than ``None`` are deserialized with ``loads``.
        """

        sfs = {}
        c = self._find_line(2, "DOCFIELD")
        while c is not None:
            v = c.get("v")
            if v is not None:
                v = loads(v)
            sfs[c["fn"]] = v
            c = self._find_line(2, "DOCFIELD")
        return sfs

    def stored_fields(self, docnum):
        """Return the stored-fields dict of a document.

        :raises Exception: if the document does not exist.
        """

        if not self._find_doc(docnum):
            raise Exception("No document %r" % (docnum,))
        return self._read_stored_fields()

    def iter_docs(self):
        """Return an iterator of ``(index, stored fields)`` pairs.

        The index comes from ``enumerate`` over :meth:`all_stored_fields`.
        """

        return enumerate(self.all_stored_fields())

    def all_stored_fields(self):
        """Yield the stored-fields dict of every document in file order."""

        for _ in self._iter_docs():
            yield self._read_stored_fields()

    def close(self):
        """Close the underlying file and mark this reader as closed."""

        self._dbfile.close()
        self.is_closed = True


class PlainFieldWriter(base.FieldWriter, LineWriter):
    """Writes terms and postings to a segment's ".trm" file as text lines.

    The file starts with a ``TERMS`` root line. Each field adds a
    ``TERMFIELD`` line, each term a ``BTEXT`` line below it, and each term's
    postings and summary are written as ``POST`` and ``TERMINFO`` lines.
    """

    def __init__(self, storage, segment):
        # Create the output file and immediately write the root line.
        self._dbfile = storage.create_file(segment.make_filename(".trm"))
        self._print_line(0, "TERMS")

    @property
    def is_closed(self):
        """Whether the underlying file reports itself closed."""
        return self._dbfile.is_closed

    def start_field(self, fieldname, fieldobj):
        """Remember ``fieldobj`` and write the ``TERMFIELD`` line."""
        self._fieldobj = fieldobj
        self._print_line(1, "TERMFIELD", fn=fieldname)

    def start_term(self, btext):
        """Start a fresh TermInfo and write the ``BTEXT`` line for the term."""
        self._terminfo = TermInfo()
        self._print_line(2, "BTEXT", t=btext)

    def add(self, docnum, weight, vbytes, length):
        """Record a posting in the current TermInfo and write a ``POST`` line.

        ``length`` is passed to the TermInfo only; it is not written to the
        ``POST`` line.
        """
        self._terminfo.add_posting(docnum, weight, length)
        self._print_line(3, "POST", dn=docnum, w=weight, v=vbytes)

    def finish_term(self):
        """Write the ``TERMINFO`` line summarizing the current term."""
        ti = self._terminfo
        self._print_line(3, "TERMINFO",
                         df=ti.doc_frequency(), weight=ti.weight(),
                         minlength=ti.min_length(), maxlength=ti.max_length(),
                         maxweight=ti.max_weight(),
                         minid=ti.min_id(), maxid=ti.max_id())

    def add_spell_word(self, fieldname, text):
        """Write a ``SPELL`` line for the given field and word."""
        self._print_line(2, "SPELL", fn=fieldname, t=text)

    def close(self):
        """Close the output file."""
        self._dbfile.close()

class PlainTermsReader(base.TermsReader, LineReader):
    """Reads terms and postings from a segment's plain-text ".trm" file.

    There is no index: lookups rewind the file and scan it forward with the
    LineReader helpers.
    """

    def __init__(self, storage, segment):
        self._dbfile = storage.open_file(segment.make_filename(".trm"))
        self._segment = segment
        self.is_closed = False

    def _find_field(self, fieldname):
        """Position the file just after the field's ``TERMFIELD`` line.

        :raises TermNotFound: if the file has no such field.
        """

        self._find_root("TERMS")
        if self._find_line(1, "TERMFIELD", fn=fieldname) is None:
            raise TermNotFound("No field %r" % fieldname)

    def _iter_fields(self):
        """Yield the name of every field in file order.

        After each yield the file is positioned just after that field's
        ``TERMFIELD`` line.
        """

        # _find_root() requires the name of the root section to look for.
        self._find_root("TERMS")
        c = self._find_line(1, "TERMFIELD")
        while c is not None:
            yield c["fn"]
            c = self._find_line(1, "TERMFIELD")

    def _iter_btexts(self):
        """Yield the term of each ``BTEXT`` line in the current field.

        After each yield the file is positioned just after that term's
        ``BTEXT`` line.
        """

        c = self._find_line(2, "BTEXT")
        while c is not None:
            yield c["t"]
            c = self._find_line(2, "BTEXT")

    def _find_term(self, fieldname, btext):
        """Position the file just after the ``BTEXT`` line of a term.

        Gives up at the first term greater than ``btext``, which assumes the
        terms of a field were written in ascending order.

        :returns: ``True`` if the term was found, otherwise ``False``.
        :raises TermNotFound: if the field does not exist.
        """

        self._find_field(fieldname)
        for t in self._iter_btexts():
            if t == btext:
                return True
            elif t > btext:
                break
        return False

    def _find_terminfo(self):
        """Build a TermInfo from the next ``TERMINFO`` line.

        The arguments of the line are passed to ``TermInfo`` as keyword
        arguments.
        """

        c = self._find_line(3, "TERMINFO")
        return TermInfo(**c)

    def __contains__(self, term):
        """Return whether the ``(fieldname, btext)`` pair is in the file."""

        fieldname, btext = term
        return self._find_term(fieldname, btext)

    def indexed_field_names(self):
        """Return an iterator over the names of the fields in the file."""

        return self._iter_fields()

    def terms(self):
        """Yield ``(fieldname, btext)`` for every term in file order."""

        for fieldname in self._iter_fields():
            for btext in self._iter_btexts():
                yield (fieldname, btext)

    def terms_from(self, fieldname, prefix):
        """Yield ``(fieldname, btext)`` for the field's terms >= ``prefix``.

        :raises TermNotFound: if the field does not exist.
        """

        self._find_field(fieldname)
        for btext in self._iter_btexts():
            if btext < prefix:
                continue
            yield (fieldname, btext)

    def items(self):
        """Yield ``((fieldname, btext), TermInfo)`` for every term."""

        for fieldname, btext in self.terms():
            yield (fieldname, btext), self._find_terminfo()

    def items_from(self, fieldname, prefix):
        """Yield ``((fieldname, btext), TermInfo)`` for terms >= ``prefix``."""

        for fieldname, btext in self.terms_from(fieldname, prefix):
            yield (fieldname, btext), self._find_terminfo()

    def term_info(self, fieldname, btext):
        """Return the TermInfo of a term.

        :raises TermNotFound: if the field or the term does not exist.
        """

        if not self._find_term(fieldname, btext):
            raise TermNotFound((fieldname, btext))
        return self._find_terminfo()

    def matcher(self, fieldname, btext, format_, scorer=None):
        """Return a ListMatcher over the postings of a term.

        The ``POST`` lines following the term are collected into parallel
        lists of document numbers, weights and values.

        :raises TermNotFound: if the field or the term does not exist.
        """

        if not self._find_term(fieldname, btext):
            raise TermNotFound((fieldname, btext))

        ids = []
        weights = []
        values = []
        c = self._find_line(3, "POST")
        while c is not None:
            ids.append(c["dn"])
            weights.append(c["w"])
            values.append(c["v"])
            c = self._find_line(3, "POST")

        return ListMatcher(ids, weights, values, format_, scorer=scorer)

    def close(self):
        """Close the underlying file and mark this reader as closed."""

        self._dbfile.close()
        self.is_closed = True


class PlainSegment(base.Segment):
    """Segment that tracks a document count and uses PlainTextCodec."""

    def __init__(self, indexname):
        base.Segment.__init__(self, indexname)
        # Number of documents; starts at 0 and is set via set_doc_count().
        self._doccount = 0

    def codec(self):
        """Return a new PlainTextCodec instance."""
        return PlainTextCodec()

    def set_doc_count(self, doccount):
        """Store the number of documents in this segment."""
        self._doccount = doccount

    def doc_count(self):
        """Return the stored number of documents."""
        return self._doccount

    def should_assemble(self):
        """Always return ``False``."""
        return False