#!/usr/bin/env python

import pickle, sys, copy, random, time

# Starting position of the 33-hole cross-shaped board, indexed [row][col].
# Cell values: 'X' = off the board, 1 = hole holding a peg, 0 = empty hole.
# Only the centre hole (3, 3) starts empty, so the game begins with 32 pegs.
init_board = (
    ( 'X', 'X', 1, 1, 1, 'X', 'X' ),
    ( 'X', 'X', 1, 1, 1, 'X', 'X' ),
    ( 1, 1, 1, 1, 1, 1, 1 ),
    ( 1, 1, 1, 0, 1, 1, 1 ),
    ( 1, 1, 1, 1, 1, 1, 1 ),
    ( 'X', 'X', 1, 1, 1, 'X', 'X' ),
    ( 'X', 'X', 1, 1, 1, 'X', 'X' )
    )

# bad_pos1..bad_pos3 are patterns for match()/bad_position(), not boards:
# an 'X' cell means "don't care", every other cell must equal the board's
# cell for the pattern to match.  Each pattern only describes the area
# around the left-hand arm of the cross; the __main__ block also registers
# their mirror images (via mirrors()).
# NOTE(review): presumably each pattern marks a position whose isolated
# left-arm peg(s) can no longer be cleared -- TODO confirm.

# A single peg at (2, 0) with the surrounding holes empty.
bad_pos1 = (
    ( 'X', 'X', 'X', 'X', 'X', 'X', 'X' ),
    ( 'X', 'X', 'X', 'X', 'X', 'X', 'X' ),
    (  1,  0, 0, 'X', 'X', 'X', 'X' ),
    (  0,  0, 0, 'X', 'X', 'X', 'X' ),
    (  0,  0, 'X', 'X', 'X', 'X', 'X' ),
    ( 'X', 'X', 'X', 'X', 'X', 'X', 'X' ),
    ( 'X', 'X', 'X', 'X', 'X', 'X', 'X' )
    )

# Pegs at both outer corners (2, 0) and (4, 0) of the left arm, the rest of
# the arm empty.
bad_pos2 = (
    ( 'X', 'X', 'X', 'X', 'X', 'X', 'X' ),
    ( 'X', 'X', 'X', 'X', 'X', 'X', 'X' ),
    (  1,  0, 0, 'X', 'X', 'X', 'X' ),
    (  0,  0, 0, 'X', 'X', 'X', 'X' ),
    (  1,  0, 0, 'X', 'X', 'X', 'X' ),
    ( 'X', 'X', 'X', 'X', 'X', 'X', 'X' ),
    ( 'X', 'X', 'X', 'X', 'X', 'X', 'X' )
    )

# A single peg at the tip (3, 0) of the left arm, the rest of the arm empty.
bad_pos3 = (
    ( 'X', 'X', 'X', 'X', 'X', 'X', 'X' ),
    ( 'X', 'X', 'X', 'X', 'X', 'X', 'X' ),
    (  0,  0, 0, 'X', 'X', 'X', 'X' ),
    (  1,  0, 0, 'X', 'X', 'X', 'X' ),
    (  0,  0, 0, 'X', 'X', 'X', 'X' ),
    ( 'X', 'X', 'X', 'X', 'X', 'X', 'X' ),
    ( 'X', 'X', 'X', 'X', 'X', 'X', 'X' )
    )

def generate_moves(board):
    """Return every legal jump on ``board``.

    A jump is a triple ``(landing, jumped, origin)`` of (row, col) pairs:
    ``landing`` is an empty hole (value 0), and ``jumped`` and ``origin``
    are the two pegs (value 1) lying next to it in a straight line.
    make_move() puts a peg on ``landing`` and empties the other two cells.

    Empty holes are visited in row-major order.  For each hole, jumps coming
    from below, above, the right and the left are listed, in that order.

    Bounds are checked before any cell is read.  Indexing first would rely
    on IndexError past the bottom/right edge and on negative indices
    silently wrapping round past the top/left edge.  The board size is
    taken from ``board`` rather than hard-coded; rows are assumed to be of
    equal length.

    :param board: sequence of rows whose cells are 0, 1 or 'X' (off board).
    :returns: list of ``(landing, jumped, origin)`` triples (empty if none).
    """
    moves = []
    n_rows = len(board)
    for r in range(n_rows):
        n_cols = len(board[r])
        for c in range(n_cols):
            if board[r][c] != 0:
                continue
            # Direction order: down, up, right, left.
            for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                r2, c2 = r + 2 * dr, c + 2 * dc
                if not (0 <= r2 < n_rows and 0 <= c2 < n_cols):
                    continue
                r1, c1 = r + dr, c + dc
                if board[r1][c1] == 1 and board[r2][c2] == 1:
                    moves.append(((r, c), (r1, c1), (r2, c2)))
    return moves

def print_board(board):
    """Write ``board`` to stdout, one row per line, followed by a blank line.

    Uses Python 2 print statements: the trailing comma on the inner print
    keeps a whole row on one line, and the bare ``print`` ends the row.
    stdout is flushed at the end, presumably so progress appears promptly
    when output is redirected.
    """
    for rows in board:
        for col in rows:
            # Each cell is followed by an extra ' ' to widen the columns.
            print col, ' ',
        print
    print
    sys.stdout.flush()
    
def make_move(board, c):
    """Apply jump ``c`` to ``board`` and return the resulting board.

    ``c`` is a ``(landing, jumped, origin)`` triple of (row, col) pairs, in
    the order generate_moves() builds them.  The landing hole gains a peg,
    and the jumped and origin holes are emptied.  ``board`` is not modified;
    a new tuple of row tuples is returned.
    """
    grid = [list(row) for row in board]
    landing, jumped, origin = c
    for (row, col), value in ((landing, 1), (jumped, 0), (origin, 0)):
        grid[row][col] = value
    return tuple(tuple(row) for row in grid)

def sum(list):
    """Return the total of the items in ``list``, starting from 0.

    NOTE: this module-level definition shadows the builtin ``sum`` for the
    whole module, and its parameter shadows the builtin ``list``.
    """
    total = 0
    for item in list:
        total = total + item
    return total

def reject(nboard):
    """Return 1 if packed board ``nboard`` is in the global ``rejects``, else 0.

    ``rejects`` is a module-level dict whose keys are packed boards already
    found not to lead to a solution (search_tree adds them, and __main__
    loads or creates the dict).  Membership is tested with ``in`` because
    ``dict.has_key`` is deprecated and does not exist on Python 3.
    """
    return 1 if nboard in rejects else 0

# Transpose the board
def lister(*args):
    """Return the positional arguments as a tuple (one transposed row)."""
    return args

def transpose(m):
    """Return the transpose of board ``m`` as a tuple of row tuples.

    Row ``i`` of the result is column ``i`` of ``m``; the rows of ``m`` are
    expected to be of equal length.  ``map(lister, *m)`` is the direct
    replacement for the deprecated, Python-2-only ``apply(map, ...)`` and
    gives the same result on Python 2.
    """
    return tuple(map(lister, *m))

def mirrors(board):
    """Return the images of ``board`` under all 8 symmetries of the square.

    The cross-shaped board and the four jump directions are unchanged by
    every rotation and reflection of the square.  A position is therefore
    exactly as (un)solvable as each of its images, and a pattern matches
    a position exactly when its image matches the position's image.

    Order of the returned list:
      0. the board itself
      1. N - S mirror (rows reversed)
      2. E - W mirror (each row reversed)
      3. transpose (reflection in the main diagonal)
      4. N - S mirror of the transpose (a 90 deg rotation)
      5. E - W mirror of the transpose (a 90 deg rotation the other way)
      6. 180 deg rotation
      7. reflection in the anti-diagonal
    Entries 6 and 7 complete the symmetry group.  Without them, symmetric
    copies of a rejected position would still be searched.  Images are not
    deduplicated: a symmetric board appears more than once.

    :param board: square board given as a sequence of equal-length rows.
    :returns: list of 8 boards, each a tuple of row tuples (except entry 0,
        which is ``board`` itself).
    """
    def flip_ns(b):
        # Reverse the order of the rows.
        return tuple(reversed(b))

    def flip_ew(b):
        # Reverse each row; built eagerly (using map() for its side effects
        # would do nothing on Python 3, where map is lazy).
        return tuple(tuple(reversed(row)) for row in b)

    turned = tuple(zip(*board))   # transpose: turned[i][j] == board[j][i]
    return [
        board,
        flip_ns(board),
        flip_ew(board),
        turned,
        flip_ns(turned),
        flip_ew(turned),
        flip_ns(flip_ew(board)),
        flip_ns(flip_ew(turned)),
    ]

# TODO: make more efficient
def match(b1, b2):
    """Return 1 if board ``b1`` fits pattern ``b2``, otherwise 0.

    Both are 7x7.  A cell of ``b2`` holding 'X' accepts any value; every
    other cell of ``b2`` must equal the corresponding cell of ``b1``.
    """
    differs = any(
        b1[i][j] != b2[i][j] and b2[i][j] != 'X'
        for i in range(7)
        for j in range(7)
    )
    return 0 if differs else 1
                
def bad_position(board):
    """Return 1 if ``board`` matches any pattern in the global ``bad_pos``.

    Patterns are tested in order with match() and the scan stops at the
    first hit; returns 0 when no pattern matches.
    """
    hit = any(match(board, pattern) for pattern in bad_pos)
    return 1 if hit else 0

def cmp(b1, b2):
    """Comparison function for ``list.sort``: most mobile boards first.

    Mobility is the number of legal follow-up jumps reported by
    generate_moves().  Returns a negative number if ``b1`` should come
    before ``b2`` (it has more moves), a positive number if after, and 0
    for a tie.  That is the contract Python 2's ``list.sort(cmp)`` expects.
    A bool such as ``m1 < m2`` is never negative, so sort() would treat no
    board as smaller than another and leave the list unchanged.

    An alternative strategy, pushing boards that bad_position() flags to
    the end, must follow the same negative/zero/positive contract.

    NOTE: this name shadows the Python 2 builtin ``cmp`` in this module.
    """
    m1 = len(generate_moves(b1))
    m2 = len(generate_moves(b2))
    return m2 - m1

class compress:
    """Pack the 33 playable holes of a board into one integer, and back.

    Bit ``k`` of a packed value holds the k-th playable cell in row-major
    order (1 = peg, 0 = empty).  The off-board 'X' cells carry no
    information and are not stored.  A whole board thus becomes a small
    hashable integer, which is the form used for ``solution`` entries and
    ``rejects`` keys.

    Plain int literals replace ``0L``/``1L``: Python 2 promotes to long
    automatically, and int and long values compare and hash equal, so
    previously pickled long keys still match.  ``enumerate`` replaces the
    Python-2-only ``map(None, ...)`` idiom, and no ``reduce`` is needed.
    """

    # Playable column range [first, stop) of each of the 7 rows.
    _SPANS = ((2, 5), (2, 5), (0, 7), (0, 7), (0, 7), (2, 5), (2, 5))

    def __init__(self):
        cells = [(row, col)
                 for row, (first, stop) in enumerate(self._SPANS)
                 for col in range(first, stop)]
        # [(shift, (row, col)), ...].  Kept as a real list because it is
        # walked on every call (zip/map objects are one-shot on Python 3).
        self.pos = list(enumerate(cells))

        # 7x7 template with every cell off the board; uncompress() copies it.
        self.board = [['X'] * 7 for _ in range(7)]

    def compress(self, board):
        """Return ``board`` (indexed board[row][col], cells 0/1) packed into an int."""
        packed = 0
        for shift, (i, j) in self.pos:
            packed |= board[i][j] << shift
        return packed

    def uncompress(self, cboard):
        """Inverse of compress(): rebuild a tuple-of-tuples board from ``cboard``."""
        grid = [row[:] for row in self.board]
        for shift, (i, j) in self.pos:
            grid[i][j] = 1 if cboard & (1 << shift) else 0
        return tuple(tuple(row) for row in grid)

def sort(nboards):
    """Return the candidate boards in the order search_tree should try them.

    Currently the identity: the same list object comes back, unchanged and
    uncopied.  Orderings that were tried and disabled are
    ``nboards.sort(cmp)`` (most follow-up moves first) and
    ``rand.shuffle(nboards)`` (random order).
    """
    return nboards

def search_tree(board, solution):
    """Depth-first search for a solution continuing from ``board``.

    :param board: current position, a tuple of row tuples.
    :param solution: packed boards (from comp.compress) from the starting
        position up to and including ``board``.

    Reads and updates module globals set up in __main__: ``t1``/``t2``
    (progress timing), ``max`` (longest ``solution`` seen so far; this
    shadows the builtin), ``pos_searched`` (node counter), ``comp`` (a
    compress instance) and ``rejects`` (packed boards known not to lead to
    a solution).

    Calls sys.exit(0) once ``solution`` holds 32 boards: the start position
    plus 31 jumps, which leaves one peg of the initial 32.  Otherwise it
    returns after every child is exhausted, and first records ``board`` and
    its mirrors() images in ``rejects``.
    """
    global t1, t2, max, pos_searched
    pos_searched = pos_searched + 1
    # Progress report every 1000 nodes: count and seconds since last report.
    if pos_searched % 1000 == 0:
        t2 = time.time()
        print pos_searched, t2 - t1
        sys.stdout.flush()
        t1 = time.time()
        
    #print len(solution)
    # Show any board at least as deep as the deepest reached so far (ties
    # are printed too because of >=).
    if len(solution) >= max:
        max = len(solution)
        print_board(board)
    if len(solution) == 32:
        print 'success!'
        print solution
        sys.exit(0)
        
    # One child board per legal jump (Python 2 map returns a list).
    candidates = generate_moves(board)
    nboards = [board] * len(candidates)
    nboards = map(make_move, nboards, candidates)
    #prioritize moves
    nboards = sort(nboards)
    for nboard in nboards:
        cnboard = comp.compress(nboard)
        # Skip children already proven to be dead ends.
        if reject(cnboard):
            continue
        # Copy so sibling branches do not share one growing list.
        nsolution = solution[:]
        nsolution.append(cnboard)
        search_tree(nboard, nsolution)
    # Reaching here means no child led to sys.exit, so no position
    # reachable from this board solves the puzzle.  Record the board and
    # its symmetric images so they are never expanded again.
    mirs = mirrors(board)
    mirs = map(comp.compress, mirs)
    # Accelerate finding bad positions
    for m in mirs:
        rejects[m] = 1

if __name__ == '__main__':
    # The names assigned below are module globals; search_tree(), reject()
    # and bad_position() read them directly.
    #
    # Bind rejects first so the KeyboardInterrupt handler can always save it,
    # even if the interrupt arrives before the brain file has been loaded.
    rejects = {}
    try:
        rand = random.Random()   # only used by a commented-out line in sort()
        moves = []
        max = 0                  # deepest solution length so far (shadows builtin)
        pos_searched = 0
        t1 = time.time()

        # Resume with the dead positions saved by an earlier interrupted run.
        # Best-effort: a missing or unreadable file just means starting from
        # scratch.  Catch Exception rather than using a bare except, so
        # Ctrl-C still reaches the handler below.
        # NOTE: unpickling can run arbitrary code; only load a brain.rejects
        # file this program wrote itself.
        try:
            with open('brain.rejects', 'rb') as f:
                rejects = pickle.load(f)
        except Exception:
            rejects = {}

        board = init_board
        comp = compress()
        cboard = comp.compress(board)
        solution = [cboard]

        # Known bad patterns together with all their symmetric images,
        # deduplicated through the dict keys.
        bad_pos = {}
        for pattern in (bad_pos1, bad_pos2, bad_pos3):
            for b in mirrors(pattern):
                bad_pos[b] = 1
        bad_pos = list(bad_pos.keys())

        search_tree(board, solution)
        # search_tree only returns if no solution exists (success exits).
        print('failure!')
    except KeyboardInterrupt:
        # Save what was learned so the next run can skip those positions.
        print(len(rejects))
        with open('brain.rejects', 'wb') as f:
            pickle.dump(rejects, f)
