#! /usr/bin/env python
##########################################################################
# NSAp - Copyright (C) CEA, 2016 - 2017
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

# System modules
from __future__ import print_function
import os
import shutil
import json
import argparse
from datetime import datetime
from pprint import pprint
import numpy
import networkx
import bct
import csv
import copy
import textwrap
from argparse import RawTextHelpFormatter

# Bredala module
# Optional dependency: when 'bredala' can be imported, its profiler flag is
# switched off and the functions named below are registered with it. Any
# error raised in this block (missing package, failing registration, ...) is
# silently ignored by the bare 'except', so the script also runs without
# bredala.
try:
    import bredala
    bredala.USE_PROFILER = False
    bredala.register("pyconnectome.metrics.schcc",
                     names=["basic_network_analysis",
                            "advanced_network_analysis", "create_graph"])
    bredala.register("pyconnectome.plotting.network",
                     names=["get_surface_parcellation_centroids",
                            "plot_network", "dict2list"])
except:
    pass

# Package import
from pyconnectome import __version__ as version
from pyconnectome.metrics.schcc import basic_network_analysis
from pyconnectome.metrics.schcc import advanced_network_analysis
from pyconnectome.metrics.schcc import create_graph
from pyconnectome.utils.encoders import NetworkResultEncoder
from pyconnectome.plotting.network import get_surface_parcellation_centroids
from pyconnectome.plotting.network import plot_network
from pyconnectome.plotting.network import dict2list

# Pyfreesurfer import
from pyfreesurfer import DEFAULT_FREESURFER_PATH
from pyfreesurfer.utils.surftools import TriSurface


# Parameters to keep trace
# NOTE(review): presumably the names of the module-level variables that the
# 'hopla' tooling records; the three names listed here are assigned further
# down in this script - confirm against hopla.
__hopla__ = ["runtime", "inputs", "outputs"]


# Description displayed in the command line help: it is passed to argparse
# through 'textwrap.dedent' in 'get_cmd_line_args'. The literal is not a raw
# string, so every backslash directly followed by a newline is a line
# continuation: the example command below ends up on one single line in the
# resulting text.
DOC = """
Network Analysis: compute different features of a network.

Command example on the HCP data:

python $HOME/git/pyconnectome/pyconnectome/scripts/pyconnectome_metrics \
    -s 100206 \
    -o /volatile/nsap/hcp/metrics \
    -m /neurospin/hcp/ANALYSIS/3T_mrtrix_reduced_connectome/mrtrix/100206/connectome_endvox.txt \
    -l /neurospin/hcp/ANALYSIS/3T_mrtrix_reduced_connectome/mrtrix/100206/labels.txt \
    -f /neurospin/hcp/ANALYSIS/3T_freesurfer/100206/T1w \
    -g \
    -i \
    -v 2
"""


def is_file(filearg):
    """ Type for argparse - checks that file exists but does not open.

    Parameters
    ----------
    filearg: str
        the path received from the command line.

    Returns
    -------
    filearg: str
        the input path, unchanged, when it points to an existing file.

    Raises
    ------
    argparse.ArgumentTypeError
        if the path is not an existing file. This is the exception argparse
        expects from a 'type' callable: its message is reported as is to the
        user. 'argparse.ArgumentError' cannot be used here since it requires
        the argument object as first parameter.
    """
    if not os.path.isfile(filearg):
        raise argparse.ArgumentTypeError(
            "The file '{0}' does not exist!".format(filearg))
    return filearg


def is_directory(dirarg):
    """ Type for argparse - checks that directory exists.

    Parameters
    ----------
    dirarg: str
        the path received from the command line.

    Returns
    -------
    dirarg: str
        the input path, unchanged, when it points to an existing directory.

    Raises
    ------
    argparse.ArgumentTypeError
        if the path is not an existing directory. This is the exception
        argparse expects from a 'type' callable: its message is reported as
        is to the user. 'argparse.ArgumentError' cannot be used here since it
        requires the argument object as first parameter.
    """
    if not os.path.isdir(dirarg):
        raise argparse.ArgumentTypeError(
            "The directory '{0}' does not exist!".format(dirarg))
    return dirarg


def get_cmd_line_args():
    """
    Build the command line parser, parse the arguments and return them.

    Returns
    -------
    options: dict
        mapping <argument name> -> <argument value>, without the 'verbose'
        entry. 'fsdir' is set to DEFAULT_FREESURFER_PATH when the option is
        not given on the command line.
    verbose: int
        the requested verbosity level (0, 1 or 2).
    """
    parser = argparse.ArgumentParser(
        prog="python pyconnectome_metrics",
        description=textwrap.dedent(DOC),
        formatter_class=RawTextHelpFormatter)

    # Mandatory options, gathered in their own help section
    mandatory = parser.add_argument_group("required arguments")
    mandatory.add_argument(
        "-s", "--subjectid", metavar="<id>", required=True,
        help="Subject identifier.")
    mandatory.add_argument(
        "-o", "--outdir", metavar="<path>", type=is_directory, required=True,
        help="directory where to output.")
    mandatory.add_argument(
        "-m", "--connectomefile", metavar="<file>", type=is_file,
        required=True,
        help="the connectome matrix in CSV format with white space sparators.")
    mandatory.add_argument(
        "-l", "--labelfile", metavar="<file>", type=is_file, required=True,
        help="the connectome labels.")

    # Short boolean switches: (flag, destination, help message)
    switches = (
        ("-g", "graphics", "if activated compute quality controls."),
        ("-i", "interactive",
         "if activated display an interactive windows."),
        ("-p", "snapnx", "if activated generate a snap of the network."),
        ("-a", "animatenx", "if activated display animate the network."))
    for flag, destination, message in switches:
        parser.add_argument(
            flag, dest=destination, action="store_true", help=message)

    # Remaining optional arguments
    parser.add_argument(
        "-f", "--fsdir", type=is_directory, metavar="<path>",
        help="the Freesurfer home directory.")
    parser.add_argument(
        "-e", "--erase", action="store_true",
        help="if activated, clean the subject folder.")
    parser.add_argument(
        "-v", "--verbose", default=0, type=int, choices=[0, 1, 2],
        help="increase the verbosity level: 0 silent, [1, 2] verbose.")

    # Turn the parsed namespace into the dict handed over to the script,
    # keeping the verbosity aside
    options = vars(parser.parse_args())
    verbose = options.pop("verbose")
    if options["fsdir"] is None:
        options["fsdir"] = DEFAULT_FREESURFER_PATH

    return options, verbose


"""
Parse the command line.
"""
inputs, verbose = get_cmd_line_args()

# Runtime context saved with the results
tool = "pyconnectome_metrics"
tool_version = version
timestamp = datetime.now().isoformat()
networkx_version = networkx.__version__
bct_version = bct.__version__
params = locals()
runtime = {
    key: params[key]
    for key in ("tool", "tool_version", "timestamp", "networkx_version",
                "bct_version")}
outputs = None
if verbose > 0:
    print("[info] Starting Network Analysis ...")
    print("[info] Runtime:")
    pprint(runtime)
    print("[info] Inputs:")
    pprint(inputs)

# Subject output directory: removed first when '-e/--erase' is set, then
# created if missing
subjectdir = os.path.join(inputs["outdir"], inputs["subjectid"])
if os.path.isdir(subjectdir) and inputs["erase"]:
    shutil.rmtree(subjectdir)
if not os.path.isdir(subjectdir):
    os.mkdir(subjectdir)


"""
Read labels and remove unknown regions
"""
# A connectome file with the '.dot' extension selects the FSL layout
fsl_inputs = inputs["connectomefile"].endswith(".dot")
# Position of the 'Unknown' region in the label file, if any. The label is
# removed here; the connectome is only trimmed further down, once the matrix
# has been loaded.
unknown_index = None
# FSL case: list of coords with label
if fsl_inputs:
    labels_array = numpy.loadtxt(inputs["labelfile"])
    connectome_labels = ["{0}-{1}".format(hemi, idx)
                         for _, _, _, hemi, idx in labels_array]
# Simple list with region names
else:
    with open(inputs["labelfile"]) as labels_stream:
        connectome_labels = [line.rstrip("\n") for line in labels_stream]
    if "Unknown" in connectome_labels:
        unknown_index = connectome_labels.index("Unknown")
        connectome_labels.pop(unknown_index)
if verbose > 0:
    print("[info] Processing connectome with {0} labels...".format(
        len(connectome_labels)))


"""
Read/symmetrize the connectome and remove loops
"""
# FSL case: not a connectome matrix yet
if fsl_inputs:
    matrix_size = len(connectome_labels)
    connectome = numpy.zeros((matrix_size, matrix_size), dtype=numpy.single)
    with open(inputs["connectomefile"]) as dotfile:
        for line in dotfile:
            split = line.strip().split()
            # Ignore blank lines
            if not split:
                continue
            # Each line is read as '<index1> <index2> <value>'; 1 is
            # subtracted from both indices to address the matrix
            idx1, idx2 = map(int, split[:2])
            nb_connections = float(split[2])
            connectome[idx1 - 1, idx2 - 1] = nb_connections
# A connectome matrix is provided
else:
    connectome = []
    with open(inputs["connectomefile"], "rt") as csvfile:
        reader = csv.reader(csvfile, delimiter=" ")
        for row in reader:
            connectome.append(row)
    connectome = numpy.asarray(connectome).astype(numpy.single)
    # Drop the row/column of the 'Unknown' region, but only when the matrix
    # really contains it, i.e. when it has one more row and column than the
    # remaining labels. A matrix that already matches the labels is left
    # untouched.
    expected_size = len(connectome_labels) + 1
    if (unknown_index is not None and
            connectome.shape == (expected_size, expected_size)):
        connectome = numpy.delete(connectome, unknown_index, axis=0)
        connectome = numpy.delete(connectome, unknown_index, axis=1)
    # Symmetrize out of place: the transpose is a view on the same buffer
    connectome = connectome + connectome.T
# Remove the loops
numpy.fill_diagonal(connectome, 0)
if verbose > 0:
    print("[info] Processing connectome with shape {0}...".format(
        connectome.shape))


"""
Generate a graph
"""
graph = create_graph(connectome, connectome_labels)


"""
Run the basic analysis, then the advanced one
"""
# The analysis functions only receive an output directory when the '-g'
# option is set
outdir = subjectdir if inputs["graphics"] else None
basic_network_features, basic_snapfiles = basic_network_analysis(
    graph, outdir=outdir)
advanced_network_features, advanced_snapfiles = advanced_network_analysis(
    graph, kstep=1, sstep=600., outdir=outdir)
# Merge both results in a copy, the advanced features being applied last
network_features = copy.deepcopy(basic_network_features)
network_features.update(advanced_network_features)


"""
Write the merged features in the subject directory
"""
network_features_file = os.path.join(subjectdir, "network_features.json")
with open(network_features_file, "wt") as open_file:
    json.dump(
        network_features, open_file, cls=NetworkResultEncoder, indent=4,
        sort_keys=True, check_circular=True)


"""
Load the FreeSurfer parcellations
"""
if inputs["graphics"]:
    # White surface and 'aparc' annotation of each hemisphere, looked up in
    # the subject folder of the FreeSurfer directory
    lh_white = os.path.join(
        inputs["fsdir"], inputs["subjectid"], "surf", "lh.white")
    lh_annot = os.path.join(
        inputs["fsdir"], inputs["subjectid"], "label", "lh.aparc.annot")
    rh_white = os.path.join(
        inputs["fsdir"], inputs["subjectid"], "surf", "rh.white")
    rh_annot = os.path.join(
        inputs["fsdir"], inputs["subjectid"], "label", "rh.aparc.annot")
    lh_surf = TriSurface.load(lh_white, annotfile=lh_annot)
    rh_surf = TriSurface.load(rh_white, annotfile=rh_annot)

    # Render the network: nodes come from the parcellation centroids, node
    # weights from the 'strengths' feature and edge weights from the graph
    centroids = get_surface_parcellation_centroids(
        lh_surf, rh_surf, connectome_labels)
    nodes = centroids.values()
    edges = graph.edges()
    edge_weights = [
        edge_data["weight"] for _, _, edge_data in graph.edges(data=True)]
    weights = numpy.asarray(dict2list(network_features["strengths"]))
    plot_network(
        nodes, connectome_labels,
        weights=weights,
        edges=edges,
        weight_node_by_color=True,
        weight_node_by_size=False,
        edge_weights=edge_weights,
        weight_edge_by_color=True,
        weight_edge_by_size=True,
        animate=inputs["animatenx"],
        snap=inputs["snapnx"],
        interactive=inputs["interactive"],
        outdir=subjectdir,
        actor_ang=(90, 180, 180))


"""
Update the outputs and save them and the inputs in a 'logs' directory.
"""
logdir = os.path.join(subjectdir, "logs")
if not os.path.isdir(logdir):
    os.mkdir(logdir)

# Gather the generated files by name from the module namespace
params = locals()
outputs = {
    key: params[key]
    for key in ("network_features_file", "basic_snapfiles",
                "advanced_snapfiles")}

# One JSON file per structure: inputs.json, outputs.json and runtime.json
structures = (("inputs", inputs), ("outputs", outputs), ("runtime", runtime))
for name, final_struct in structures:
    log_file = os.path.join(logdir, "{0}.json".format(name))
    with open(log_file, "wt") as open_file:
        json.dump(
            final_struct, open_file, indent=4, sort_keys=True,
            check_circular=True)
if verbose > 1:
    print("[info] Outputs:")
    pprint(outputs)

