#-*- coding:utf-8 -*-

#  Copyright © 2009-2015  B. Clausius <barcc@gmx.de>
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License
#  along with this program.  If not, see <http://www.gnu.org/licenses/>.


import re
from collections import namedtuple

from .debug import debug


class MoveData (namedtuple('_MoveData', 'axis slice dir')):     # pylint: disable=W0232
    """A single move, stored as the tuple (axis, slice, dir).

    axis  = 0...
    slice = 0...dim-1 or -1 for all slices
    dir   = 0 or 1
    """
    __slots__ = ()

    def inverted(self):
        """Return a new move on the same axis and slice with `dir` negated.

        The direction of the result is the boolean `not self.dir`.
        """
        axis, slice_, direction = self
        # _make is the constructor that _replace uses internally
        return self._make((axis, slice_, not direction))  # pylint: disable=E1101

class MoveDataPacked (namedtuple('_MoveDataPacked', 'axis counts fullcount')):     # pylint: disable=W0232
    """Several moves around one axis, packed into per-slice turn counts.

    axis  = 0...
    counts = [...],  length: model.sizes[axis], values: number of slice-rotations, >0 CW <0 CCW
    fullcount = number of full rotations, >0 CW <0 CCW
    """
    __slots__ = ()

    def length(self):
        """Return the total number of single slice turns stored in `counts`."""
        return sum(abs(c) for c in self.counts)

    def get_full_moves(self, model):
        """Return the full-rotation count that leaves the fewest slice turns.

        For every candidate shift s in range(symmetry) the method sums the
        absolute values of the slice counts after adding s (values above
        symmetry//2 wrap around by subtracting symmetry).  The first shift
        with the smallest sum wins.  The result is that shift expressed as a
        number of full rotations, added to the already recorded
        `self.fullcount` and wrapped so it does not exceed symmetry//2.

        model: object providing `symmetries`, indexed by axis.
        """
        symmetry = model.symmetries[self.axis]
        half = symmetry // 2
        vals = [0] * symmetry
        for scount in self.counts:
            for s in range(symmetry):
                v = scount + s
                if v > half:
                    v -= symmetry
                vals[s] += abs(v)
        # min() returns the first of several equally good shifts
        best = min(range(symmetry), key=vals.__getitem__)
        # adding `best` to every slice equals taking out symmetry-best full rotations
        f = symmetry - best + self.fullcount
        if f > half:
            f -= symmetry
        return f

    def update_fullmoves(self, model):
        """Return a copy with common slice turns moved into `fullcount`.

        Only the rotations that are newly extracted (the difference between
        the value from `get_full_moves` and the current `fullcount`) are
        subtracted from the slice counts.  Subtracting the whole total would
        count the already recorded full rotations twice and change the net
        rotation whenever `fullcount` is not 0.

        model: object providing `symmetries`, indexed by axis.
        """
        fullcount = self.get_full_moves(model)
        extracted = fullcount - self.fullcount
        symmetry = model.symmetries[self.axis]
        half = symmetry // 2
        counts = []
        for sdirs in self.counts:
            sdirs -= extracted
            # bring the count back into the range used by the other methods
            if sdirs < 0:
                sdirs += symmetry
            if sdirs > half:
                sdirs -= symmetry
            counts.append(sdirs)
        return self._replace(counts=counts, fullcount=fullcount)
        
        
class _MoveQueueBase:
    """State and helpers shared by the move queue classes.

    Subclasses have to define `_pos_0`, the initial value of `current_place`.
    """

    class MoveQueueItem:
        """One queue entry: move data plus a mark flag and a cached code string."""

        def __init__(self, data, *, mark_after=False, code=None):
            assert isinstance(data, (MoveData, MoveDataPacked))
            self.data = data
            self.mark_after = mark_after
            self.code = code

        # read-only shortcuts to the fields of the wrapped move data
        @property
        def axis(self):
            return self.data.axis

        @property
        def slice(self):
            return self.data.slice

        @property
        def dir(self):
            return self.data.dir

        def copy(self):
            """Return a new item of the same class with the same data, mark and code."""
            return type(self)(self.data, mark_after=self.mark_after, code=self.code)

        def invert(self):
            """Replace the data by its inverted counterpart."""
            self.data = self.data.inverted()

        def rotate_by(self, model, totalmove):
            """Replace the data by `model.rotate_move(totalmove.data, self.data)`.

            The cached code is dropped because it no longer matches the data.
            """
            rotated = model.rotate_move(totalmove.data, self.data)
            self.data = type(self.data)(*rotated)
            self.code = None

        @classmethod
        def normalize_complete_moves(cls, model, totalmoves):
            """Yield new items for the moves returned by the model.

            The data of `totalmoves` is handed to
            `model.normalize_complete_moves`; each result is wrapped in
            `MoveData` and in a fresh item.
            """
            datas = (item.data for item in totalmoves)
            for normalized in model.normalize_complete_moves(datas):
                yield cls(MoveData(*normalized))

        def unpack(self):
            """Yield (axis, slice, dir) tuples for packed data.

            First come the slice rotations, slice by slice, then the full
            rotations with slice -1.  `dir` is True for negative counts.
            """
            axis, slice_counts, full_count = self.data
            # slice rotations
            for index, count in enumerate(slice_counts):
                backwards = count < 0
                for _ in range(abs(count)):
                    yield axis, index, backwards
            # total rotations
            backwards = full_count < 0
            for _ in range(abs(full_count)):
                yield axis, -1, backwards

        def update_fullmoves(self, model):
            """Replace the data by `self.data.update_fullmoves(model)`."""
            self.data = self.data.update_fullmoves(model)

    def __init__(self):
        self.current_place = self._pos_0
        self.moves = []

    def copy(self):
        """Return a queue of the same class with copied items and the same position."""
        twin = type(self)()
        twin.current_place = self.current_place
        twin.moves = [item.copy() for item in self.moves]
        return twin

    def __str__(self):
        return '{}(len={})'.format(type(self).__name__, len(self.moves))
        
        
class MoveQueue (_MoveQueueBase):
    """Queue of single moves with a cursor.

    `moves` holds `MoveQueueItem` objects wrapping `MoveData`.
    `current_place` is an index into `moves`: 0 is the start of the queue,
    `len(moves)` is the end.
    """
    _pos_0 = 0

    @classmethod
    def new_from_code(cls, code, pos, model):
        """Create a queue from a code string.

        Returns (queue, queue position, code position); the two positions
        are the values returned by `parse`.
        """
        moves = cls()
        mpos, cpos = moves.parse(code, pos, model)
        return moves, mpos, cpos

    def at_start(self):
        """Return True if the cursor is in front of the first move."""
        return self.current_place == 0

    def at_end(self):
        """Return True if the cursor is behind the last move."""
        return self.current_place == self.queue_length

    @property
    def _prev(self):
        # item directly before the cursor
        return self.moves[self.current_place-1]

    @property
    def _current(self):
        # item directly at the cursor
        return self.moves[self.current_place]

    @property
    def queue_length(self):
        """Number of items in the queue."""
        return len(self.moves)

    def push(self, move_data, **kwargs):
        """Append a move; `kwargs` are passed on to `MoveQueueItem`."""
        new_element = self.MoveQueueItem(MoveData(*move_data), **kwargs)
        self.moves.append(new_element)

    def push_current(self, move_data):
        """Drop everything from the cursor on, then append the move.

        The cursor itself is not changed.
        """
        if not self.at_end():
            self.moves[self.current_place:] = []
        self.push(move_data)

    def current(self):
        """Return the move data at the cursor, or None at the end."""
        return None if self.at_end() else self._current.data

    def retard(self):
        """Move the cursor one step back unless it is at the start."""
        if not self.at_start():
            self.current_place -= 1

    def rewind_start(self):
        """Move the cursor to the start."""
        self.current_place = 0

    def forward_end(self):
        """Move the cursor to the end."""
        self.current_place = self.queue_length

    def truncate(self):
        """Remove all moves from the cursor on."""
        self.moves[self.current_place:] = []

    def truncate_before(self):
        """Remove all moves before the cursor and set the cursor to the start."""
        self.moves[:self.current_place] = []
        self.current_place = 0

    def reset(self):
        """Remove all moves and set the cursor to the start."""
        self.current_place = 0
        self.moves[:] = []

    def advance(self):
        """Move the cursor one step forward unless it is at the end."""
        if not self.at_end():
            self.current_place += 1

    def swapnext(self):
        """Exchange the data of the move at the cursor and its successor.

        Only the data is swapped (marks stay in place), the cached codes of
        both items are dropped and the cursor advances.  Does nothing if
        there are not two moves from the cursor on.
        """
        cp = self.current_place
        if cp+1 < self.queue_length:
            ma, mb = self.moves[cp], self.moves[cp+1]
            ma.code = mb.code = None
            ma.data, mb.data = mb.data, ma.data
            self.advance()

    def swapprev(self):
        """Exchange the data of the move at the cursor and its predecessor.

        Counterpart of `swapnext`; the cursor moves one step back.  Does
        nothing at the start or at the end of the queue.
        """
        cp = self.current_place
        if 0 < cp < self.queue_length:
            ma, mb = self.moves[cp-1], self.moves[cp]
            ma.code = mb.code = None
            ma.data, mb.data = mb.data, ma.data
            self.retard()

    def invert(self):
        """Invert every move and reverse the order of the queue.

        A cursor inside the queue is mirrored; a cursor at the start or at
        the end keeps its value.
        """
        # a mark at the end of the moves is discarded because a mark at start is not supported
        mark = False
        for move in self.moves:
            move.invert()
            move.code = None
            # shift each mark to the following move, so that after the
            # reversal it is again located between the same two moves
            move.mark_after, mark = mark, move.mark_after
        self.moves.reverse()
        if not (self.at_start() or self.at_end()):
            self.current_place = len(self.moves) - self.current_place

    def normalize_complete_rotations(self, model):
        """Move all whole-cube rotations (slice < 0) to the end of the queue.

        Every slice move is rotated by the whole-cube rotations that
        preceded it (latest first), and the collected rotations are replaced
        by the result of `MoveQueueItem.normalize_complete_moves`.

        A cursor that was on a move ends up at the number of slice moves
        kept before it.  A cursor at the end stays at the end.
        """
        # The loop below never sees the end position.  Remember it, because
        # model.normalize_complete_moves may yield a different number of
        # moves than it was given; the old index could then point beyond
        # the queue.
        was_at_end = self.current_place >= len(self.moves)
        totalmoves = []
        new_moves = []
        for i, move in enumerate(self.moves):
            if i == self.current_place:
                self.current_place = len(new_moves)
            if move.slice < 0:
                totalmoves.append(move)
            else:
                for totalmove in reversed(totalmoves):
                    move.rotate_by(model, totalmove)
                new_moves.append(move)
        totalmoves = list(self.MoveQueueItem.normalize_complete_moves(model, totalmoves))
        self.moves = new_moves + totalmoves
        if was_at_end:
            self.current_place = len(self.moves)

    def unpack(self, move):
        """Append the single moves contained in a packed queue item."""
        for move_data in move.unpack():
            self.push(move_data)

    def normalize_moves(self, model):
        """Rebuild the queue from a packed representation of its moves.

        All moves are packed into a `MoveQueuePacked`, each packed move is
        passed through `update_fullmoves` and unpacked again.  The new items
        are created by `push` without arguments, so marks and cached codes
        are not carried over.
        """
        mqp = MoveQueuePacked()
        for i, move in enumerate(self.moves):
            if i == self.current_place:
                # remember the cursor as (index of the last packed move,
                # number of single turns it holds so far)
                mqp.current_place = (len(mqp.moves)-1, mqp.moves[-1].data.length()) if mqp.moves else (0, 0)
            mqp.pack(move.data, model)
        if len(self.moves) <= self.current_place:
            # cursor at the end
            mqp.current_place = len(mqp.moves), 0
        self.reset()
        for i, move in enumerate(mqp.moves):
            move.update_fullmoves(model)
            if i == mqp.current_place[0]:
                self.current_place = len(self.moves) + mqp.current_place[1]
            self.unpack(move)
        if len(mqp.moves) <= mqp.current_place[0]:
            self.current_place = len(self.moves)

    def is_mark_current(self):
        """Return True at the start or if the previous move has a mark after it."""
        return self.at_start() or self._prev.mark_after

    def is_mark_after(self, pos):
        """Return the mark flag of the move at index `pos`."""
        return self.moves[pos].mark_after

    def mark_current(self, mark=True):
        """Set or clear the mark after the move before the cursor.

        Does nothing at the start.  A cached code is kept in sync: all
        spaces are removed and one space is appended if `mark` is true.
        """
        if not self.at_start():
            self._prev.mark_after = mark
            if self._prev.code is not None:
                self._prev.code = self._prev.code.replace(' ','')
                if mark:
                    self._prev.code += ' '

    def mark_and_extend(self, other):
        """Replace everything from the cursor on by the moves of `other`.

        A mark is set at the cursor first.  The items of `other` are
        appended as they are, not copied.  Returns the cursor position of
        `other` translated into this queue, or -1 if `other` is empty (in
        which case nothing is changed).
        """
        if not other.moves:
            return -1
        self.mark_current()
        self.truncate()
        self.moves += other.moves
        return self.current_place + other.current_place

    def format(self, model):
        """Return (code, pos) for the whole queue.

        `code` is the concatenation of the codes of all moves; missing codes
        are created with `MoveFormat.format` and cached on the items.  `pos`
        is the length of the code of the moves before the cursor.
        """
        parts = []
        length = 0
        pos = 0
        for i, move in enumerate(self.moves):
            if move.code is None:
                move.code = MoveFormat.format(move, model)
            parts.append(move.code)
            length += len(move.code)
            if i < self.current_place:
                pos = length
        return ''.join(parts), pos

    def parse_iter(self, code, pos, model):
        """Parse `code` chunk by chunk and append the valid moves.

        Leading spaces are stripped first.  A chunk ends where the next
        character accepted by `MoveFormat.isstart` begins.  After each chunk
        a tuple (data, queue_pos, code_pos) is yielded: `data` is the parsed
        move or None if the chunk is invalid, `queue_pos` is the queue
        position tracked for `pos` so far and `code_pos` is the index in the
        stripped code where the chunk ends.
        """
        code = code.lstrip(' ')
        queue_pos = self.current_place
        move_code = ''
        for i, c in enumerate(code):
            if move_code and MoveFormat.isstart(c, model):
                data, mark = MoveFormat.parse(move_code, model)
                if data is not None:
                    #FIXME: invalid chars at start get lost, other invalid chars are just ignored
                    self.push(data, mark_after=mark, code=move_code)
                yield data, queue_pos, i
                if i == pos:
                    # pos is exactly on the boundary between two chunks
                    queue_pos = self.queue_length
                move_code = c
            else:
                move_code += c
            if i < pos:
                queue_pos = self.queue_length + 1
        if move_code:
            # the last chunk has no following start character
            data, mark = MoveFormat.parse(move_code, model)
            if data is not None:
                self.push(data, mark_after=mark, code=move_code)
            if len(code)-len(move_code) < pos:
                queue_pos = self.queue_length
            yield data, queue_pos, len(code)

    def parse(self, code, pos, model):
        """Parse the whole code and return (queue_pos, cpos).

        `queue_pos` is the value from the last chunk of `parse_iter` (0 if
        there is none); `cpos` is the end of the first chunk that reaches
        `pos`, or the end of the last chunk.
        """
        queue_pos = 0
        cpos = 0
        for unused_data, queue_pos, code_len in self.parse_iter(code, pos, model):
            if cpos < pos:
                cpos = code_len
        return queue_pos, cpos
        
        
class MoveQueuePacked (_MoveQueueBase):
    """Queue whose items hold `MoveDataPacked`: successive moves around the
    same axis are merged into one item.
    """
    _pos_0 = (0, 0, 0)

    def length(self):
        """Return the summed `length()` of the data of all items."""
        return sum(item.data.length() for item in self.moves)

    def push(self, move_data, **kwargs):
        """Append packed move data; `kwargs` are passed on to `MoveQueueItem`."""
        item = self.MoveQueueItem(MoveDataPacked(*move_data), **kwargs)
        self.moves.append(item)

    def pack(self, move, model):
        """Merge a single move into the last item or start a new item.

        A new item is started if the queue is empty or the last item uses a
        different axis.  A move with slice < 0 is counted on every slice.
        Afterwards the counts are wrapped into the range limited by
        `model.symmetries[axis]`, and the last item is removed again if
        nothing is left of it.  Returns the packed data that was modified.
        """
        axis = move.axis
        packed = self.moves[-1].data if self.moves else None
        if packed is None or packed.axis != axis:
            packed = MoveDataPacked(axis, [0] * model.sizes[axis], 0)
            # the pushed item shares the counts list with `packed`
            self.push(packed)
        counts = packed.counts

        # count the move on the affected slice(s)
        step = -1 if move.dir else 1
        if move.slice < 0:
            for index in range(model.sizes[axis]):
                counts[index] += step
        else:
            counts[move.slice] += step

        # wrap the counts around
        order = model.symmetries[axis]
        for index, count in enumerate(counts):
            if count < 0:
                count += order
            if count > order//2:
                count -= order
            counts[index] = count

        if not (any(counts) or packed.fullcount):
            self.moves.pop()
        return packed

    def format(self, model):
        """Return the result of `MoveQueue.format` for the unpacked moves."""
        flat = MoveQueue()
        for item in self.moves:
            flat.unpack(item)
        return flat.format(model)
        
                
class MoveFormat: # pylint: disable=W0232
    """Conversion between move data and its textual code.

    A code consists of a face symbol, an optional slice number, an optional
    direction sign (' or -) and optional trailing spaces that stand for a
    mark.  Upper-case face symbols are read as rotations of all slices.
    """
    # groups: face, slice number, direction, ignored characters, mark, ignored rest
    re_flubrd = re.compile(r"(.)(\d*)(['-]?)([^ ]*)( *)(.*)")

    @classmethod
    def isstart(cls, char, model):
        """Return True if `char` (in any case) is one of `model.faces`."""
        return char.upper() in model.faces

    @staticmethod
    def intern_to_str_move(move, model):
        """Return the strings (face, slice, direction) for a move.

        move:  object with `axis`, `slice` and `dir`.
        model: object providing `symbols`, `symbolsI`, `faces` and `sizes`.
        """
        if move.slice <= -1:
            # Rotate entire cube
            if move.dir:
                mface = model.symbolsI[move.axis]
                if mface in model.faces:
                    # opposite symbol not reversed
                    return mface, '', ''
                # no opposite symbol
            mface = model.symbols[move.axis]
            mdir = '-' if move.dir else ''
            return mface, '', mdir
        elif move.slice*2 > model.sizes[move.axis]-1:
            mface = model.symbolsI[move.axis]
            if mface in model.faces:
                # slice is nearer to the opposite face
                mface = mface.lower()
                mslice = model.sizes[move.axis]-1 - move.slice
                mslice = str(mslice+1) if mslice else ''
                # seen from the opposite face the direction is reversed
                mdir = '' if move.dir else '-'
                return mface, mslice, mdir
            # no opposite face
        mface = model.symbols[move.axis].lower()
        mslice = str(move.slice+1) if move.slice else ''
        mdir = '-' if move.dir else ''
        return mface, mslice, mdir

    @classmethod
    def format(cls, move, model):
        """Return the code of a queue item; a set mark adds a trailing space."""
        mface, mslice, mdir = cls.intern_to_str_move(move, model)
        #print('format:', move.data, '->', (mface, mslice, mdir))
        mark = ' ' if move.mark_after else ''
        move_code = mface + mslice + mdir + mark
        return move_code

    @staticmethod
    def str_to_intern_move(tface, tslice, tdir, model):
        """Return the `MoveData` for the parts of a code, or None if invalid.

        tface:  face symbol; upper case selects all slices (slice -1)
        tslice: slice number as string, counted from 1; empty means 1
        tdir:   direction sign; any non-empty string sets dir to True
        Symbols from `model.symbolsI` are counted from the opposite side
        and reverse the direction.
        """
        mface = tface.upper()
        mslice = int(tslice or 1) - 1
        mdir = bool(tdir)
        if mface not in model.faces:
            return None
        elif mface in model.symbols:
            maxis = model.symbols.index(mface)
        elif mface in model.symbolsI:
            maxis = model.symbolsI.index(mface)
            mslice = model.sizes[maxis]-1 - mslice
            mdir = not mdir
        else:
            assert False, 'faces is a subset of symbols+symbolsI'
        if mslice < 0 or mslice >= model.sizes[maxis]:
            return None
        if tface.isupper():
            mslice = -1
        return MoveData(maxis, mslice, mdir)

    @classmethod
    def parse(cls, move_code, model):
        """Parse one code and return (move, mark).

        For an invalid code a debug message is emitted and (None, False) is
        returned.  Characters matched by the two "ignored" groups of the
        pattern are discarded.
        """
        match = cls.re_flubrd.match(move_code)
        if match is None:
            # The pattern needs at least one character other than a newline
            # at the start; treat anything else like an unknown move instead
            # of failing on match.groups().
            move = None
        else:
            tface, tslice, tdir, unused_err1, mark, unused_err2 = match.groups()
            move = cls.str_to_intern_move(tface, tslice, tdir, model)
        #print('parse:', (tface, tslice, tdir), '->', move)
        if move is None:
            debug('Error parsing formula')
            return None, False
        return move, bool(mark)
        
        
