#!/usr/bin/env python

import pickle
import argparse
import numpy as np
import mdtraj as md
from glob import glob
from itertools import chain
from contextlib import closing
from multiprocessing import Pool
from mdentropy.mutinf import MutualInformation
from mdentropy.utils import timing, shuffle, dihedrals
from itertools import combinations_with_replacement as combinations


def run(traj, nbins, iter, N):
    """Compute the pairwise mutual-information matrix between dihedrals.

    The first pass uses the unshuffled dihedral data. Each of the ``iter``
    further passes uses data re-shuffled with ``shuffle``. When ``iter > 0``
    the mean of the shuffled matrices is subtracted from the unshuffled one
    as a baseline correction.

    Parameters
    ----------
    traj : mdtraj.Trajectory
        Trajectory passed to ``dihedrals``. It is presumably an ``md.load``
        result, as produced in ``__main__``.
    nbins : int
        Number of bins forwarded to ``MutualInformation``.
    iter : int
        Number of shuffle iterations. Must be >= 0. The name shadows the
        builtin but is kept for backward compatibility with callers.
    N : int
        Number of worker processes in the multiprocessing pool.

    Returns
    -------
    numpy.ndarray
        Symmetric ``(n, n)`` matrix. ``n`` is the number of distinct column
        labels across the tables returned by ``dihedrals`` (presumably pandas
        DataFrames). Rows and columns follow the sorted order produced by
        ``np.unique``.

    Raises
    ------
    ValueError
        If ``iter`` is negative. This previously surfaced as an obscure
        ``IndexError``.
    """
    if iter < 0:
        raise ValueError('iter must be non-negative, got %r' % (iter,))

    D = dihedrals(traj)
    # Sorted, de-duplicated dihedral labels gathered from every table.
    n = np.unique(np.hstack([np.asarray(df.columns) for df in D]))
    size = n.size
    # Upper triangle *including* the diagonal, in row-major order. This
    # matches the order in which combinations_with_replacement(n, 2) yields
    # (a, a), (a, b), ..., so pool.map results can be assigned positionally.
    idx = np.triu_indices(size)
    lower = idx[::-1]

    observed = None
    # Running sum of shuffled matrices: O(n^2) memory instead of keeping
    # every pass around just to average them.
    shuffled_sum = np.zeros((size, size))

    # One pool for all passes instead of spawning a fresh one per pass.
    with closing(Pool(processes=N)) as pool:
        for i in range(iter + 1):
            g = MutualInformation(nbins, D)
            with timing(i):
                r = np.zeros((size, size))
                r[idx] = pool.map(g, combinations(n, 2))
                r[lower] = r[idx]  # mirror upper triangle into lower
            if i == 0:
                observed = r
            else:
                shuffled_sum += r
            # Prepare shuffled data for the next pass; skip after the last.
            if i < iter:
                D = [shuffle(df) for df in D]
        # map() blocks until every task finishes, so workers can be
        # stopped immediately rather than left to idle out.
        pool.terminate()

    if iter > 0:
        return observed - shuffled_sum / iter
    return observed


def parse_cmdln(args=None):
    """Parse the command-line options of the mutual-information script.

    Parameters
    ----------
    args : list of str, optional
        Arguments to parse. The default ``None`` makes argparse read
        ``sys.argv[1:]``, which matches the original behaviour.

    Returns
    -------
    argparse.Namespace
        Has the attributes ``traj``, ``iter``, ``top``, ``nbins``, ``N``,
        ``stride`` and ``out``.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('-i', '--input', dest='traj',
                        help='File containing trajectory.', required=True)
    parser.add_argument('-s', '--shuffle-iter', dest='iter',
                        help='Number of shuffle iterations.',
                        default=100, type=int)
    parser.add_argument('-t', '--topology', dest='top',
                        help='File containing topology.', default=None)
    parser.add_argument('-b', '--n-bins', dest='nbins',
                        help='Number of bins', default=24, type=int)
    parser.add_argument('-n', '--n-proc', dest='N',
                        help='Number of processors to be used.',
                        default=12, type=int)
    # type=int is required: without it a user-supplied stride reaches
    # md.load as a string, while the default is an int.
    parser.add_argument('-r', '--stride', dest='stride',
                        help='Stride to use', default=100, type=int)
    parser.add_argument('-o', '--output', dest='out',
                        help='Name of output file.', default='mutinf.pkl')
    return parser.parse_args(args)


if __name__ == "__main__":
    options = parse_cmdln()
    # --input accepts a comma-separated list of glob patterns.
    expr = options.traj.replace(' ', '').split(',')
    # glob() returns matches in arbitrary filesystem order; sort each
    # pattern's matches so the loaded frame order is reproducible.
    files = list(chain.from_iterable(sorted(glob(e)) for e in expr if e))
    if not files:
        raise SystemExit('No trajectory files matched: %s' % options.traj)
    traj = md.load(files, top=options.top, stride=options.stride)
    M = run(traj, options.nbins, options.iter, options.N)
    # Close the output file deterministically instead of relying on GC.
    with open(options.out, 'wb') as fh:
        pickle.dump(M, fh)
