#!/usr/bin/env python

import pickle
import argparse
import numpy as np
import mdtraj as md
import pandas as pd
from glob import glob
from itertools import chain
from itertools import product
from contextlib import closing
from multiprocessing import Pool
from mdentropy.utils import timing, dihedrals, shuffle
from mdentropy.tent import TransferEntropy


def run(current, past, nbins, iter, N):
    """Build an antisymmetrised, shuffle-corrected pairwise matrix.

    ``TransferEntropy`` is evaluated for every ordered pair of column labels
    found in the dihedral tables of ``current``.  The evaluation is repeated
    ``iter`` more times after passing the ``current`` tables through
    ``shuffle``; the mean of those shuffled passes is subtracted from the
    first (unshuffled) pass.

    Parameters
    ----------
    current, past
        Inputs handed to ``dihedrals`` (presumably mdtraj trajectories --
        confirm against the caller).
    nbins
        Forwarded as the first argument of ``TransferEntropy``.
    iter
        Number of shuffled passes performed after the unshuffled one.
    N
        Number of worker processes used by each ``multiprocessing.Pool``.

    Returns
    -------
    pandas.DataFrame
        ``T - T.T`` with the unique column labels as column names, where
        ``T`` is the corrected matrix described above.
    """
    current_tables = dihedrals(current)
    past_tables = dihedrals(past)

    # Sorted, de-duplicated union of the column labels of every table.
    labels = np.unique(
        np.hstack(tuple(np.array(table.columns) for table in current_tables)))
    shape = (labels.size, labels.size)

    passes = []
    for step in range(iter + 1):
        estimator = TransferEntropy(nbins, current_tables, past_tables)
        with timing(step):
            # A fresh pool per pass: the estimator changes after each shuffle.
            with closing(Pool(processes=N)) as workers:
                flat = pool_values = workers.map(estimator,
                                                 product(labels, labels))
                workers.terminate()
            passes.append(np.reshape(pool_values, shape).T)
            # Pass 0 uses the data as given; later passes use shuffled data.
            current_tables = [shuffle(table) for table in current_tables]

    corrected = passes[0]
    if iter > 0:
        # In-place subtraction of the mean over the shuffled passes.
        corrected -= np.mean(passes[1:], axis=0)
    return pd.DataFrame(corrected - corrected.T, columns=labels)


def parse_cmdln():
    """Parse the script's command-line options from ``sys.argv``.

    Returns
    -------
    argparse.Namespace
        With attributes ``current`` and ``past`` (required strings),
        ``iter`` (int, default 10), ``stride`` (int, default 100),
        ``top`` (default ``None``), ``nbins`` (int, default 24),
        ``N`` (int, default 12) and ``out`` (default ``'tent.pkl'``).
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('-c', '--current', dest='current',
                        help='File containing current step states.',
                        required=True)
    parser.add_argument('-p', '--past', dest='past',
                        help='File containing past step states.',
                        required=True)
    parser.add_argument('-s', '--shuffle-iter', dest='iter',
                        help='Number of shuffle iterations.',
                        default=10, type=int)
    # type=int so a stride given on the command line is an integer, matching
    # the integer default instead of arriving as a string.
    parser.add_argument('-r', '--stride', dest='stride',
                        help='Stride to use', default=100, type=int)
    parser.add_argument('-t', '--topology',
                        dest='top', help='File containing topology.',
                        default=None)
    parser.add_argument('-b', '--n-bins', dest='nbins',
                        help='Number of bins', default=24, type=int)
    parser.add_argument('-n', '--n-proc', dest='N',
                        help='Number of processors', default=12, type=int)
    parser.add_argument('-o', '--output', dest='out',
                        help='Name of output file.', default='tent.pkl')
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    # Script entry point: load the trajectories named on the command line,
    # run the analysis and pickle the resulting DataFrame.
    options = parse_cmdln()

    # --current / --past each hold a comma-separated list of glob patterns;
    # spaces are removed before splitting.
    expr1, expr2 = (options.current.replace(' ', '').split(','),
                    options.past.replace(' ', '').split(','))
    f1, f2 = (list(chain(*map(glob, expr1))),
              list(chain(*map(glob, expr2))))

    current = md.load(f1, top=options.top)
    # The stride option is stored under dest='stride' (there is no
    # ``options.r``); int() guards against the value arriving as text.
    # NOTE(review): the stride is applied to ``past`` only -- confirm that
    # leaving ``current`` unstrided is intentional.
    past = md.load(f2, top=options.top, stride=int(options.stride))

    D = run(current, past, options.nbins, options.iter, options.N)

    # Context manager so the output file is flushed and closed deterministically.
    with open(options.out, 'wb') as handle:
        pickle.dump(D, handle)
