#! /usr/bin/env python
##########################################################################
# NSAp - Copyright (C) CEA, 2017
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

# System modules
from __future__ import print_function
import os
import shutil
import json
import csv
import argparse
from datetime import datetime
from pprint import pprint
import numpy
import textwrap
from argparse import RawTextHelpFormatter

# Bredala module
# Best-effort set-up: when 'bredala' cannot be imported or the registration
# fails, the script silently carries on without it.
try:
    import bredala
    bredala.USE_PROFILER = False
    bredala.register("pyconnectome.metrics.schcc",
                     names=["metric_profile"])
    bredala.register("pyfreesurfer.plots.formatting",
                     names=["sort_features"])
# 'Exception' rather than a bare 'except' so that KeyboardInterrupt and
# SystemExit are not swallowed.
except Exception:
    pass

# Package import
from pyconnectome import __version__ as version
from pyconnectome.metrics.schcc import metric_profile

# Third party import
from pyfreesurfer.plots.formatting import sort_features


# Names of the module-level variables to keep a trace of
__hopla__ = ["runtime", "inputs", "outputs"]


# Help message displayed by the command line parser.
# Raw string: in a regular string a backslash followed by a newline is a line
# continuation, which would remove the shell continuations and collapse the
# command example below onto a single line.
DOC = r"""
Network Analysis: compute the heat map with all the user metric profiles.

Command example on the HCP data:

python $HOME/git/pyconnectome/pyconnectome/scripts/pyconnectome_metrics_heatmap \
    -o /volatile/nsap/hcp/metrics \
    -i /neurospin/hcp/ANALYSIS/3T_mrtrix_reduced_connectome/mrtrix_metrics/*/network_features.json \
    -v 2
"""

def is_file(filearg):
    """ Type for argparse - checks that file exists but does not open.

    Parameters
    ----------
    filearg: str
        the path received from the command line.

    Returns
    -------
    filearg: str
        the input path, unchanged, when it points to an existing file.

    Raises
    ------
    argparse.ArgumentTypeError
        if the path is not an existing file.
    """
    if not os.path.isfile(filearg):
        # 'ArgumentTypeError' is the exception argparse expects from a 'type'
        # callable: its message is reported to the user.
        # 'ArgumentError' requires an extra 'argument' parameter and would
        # fail with a TypeError when built from the message alone.
        raise argparse.ArgumentTypeError(
            "The file '{0}' does not exist!".format(filearg))
    return filearg

def is_directory(dirarg):
    """ Type for argparse - checks that directory exists.

    Parameters
    ----------
    dirarg: str
        the path received from the command line.

    Returns
    -------
    dirarg: str
        the input path, unchanged, when it points to an existing directory.

    Raises
    ------
    argparse.ArgumentTypeError
        if the path is not an existing directory.
    """
    if not os.path.isdir(dirarg):
        # 'ArgumentTypeError' is the exception argparse expects from a 'type'
        # callable: its message is reported to the user.
        # 'ArgumentError' requires an extra 'argument' parameter and would
        # fail with a TypeError when built from the message alone.
        raise argparse.ArgumentTypeError(
            "The directory '{0}' does not exist!".format(dirarg))
    return dirarg


def get_cmd_line_args():
    """ Build the command line parser and parse the script arguments.

    Returns
    -------
    namespace: dict
        mapping <argument name> -> <argument value>, without the 'verbose'
        entry.
    verbosity: int
        the requested verbosity level (0, 1 or 2).
    """
    parser = argparse.ArgumentParser(
        formatter_class=RawTextHelpFormatter,
        description=textwrap.dedent(DOC),
        prog="python pyconnectome_metrics_heatmap")

    # Mandatory options are gathered in a dedicated group of the help message
    mandatory = parser.add_argument_group("required arguments")
    mandatory.add_argument(
        "-o", "--outdir",
        type=is_directory, metavar="<path>", required=True,
        help="directory where to output.")
    mandatory.add_argument(
        "-i", "--featurefiles",
        type=is_file, nargs="+", metavar="<files>", required=True,
        help="the regex to access the subject JSON network features.")

    # Optional options
    parser.add_argument(
        "-v", "--verbose",
        default=0, choices=[0, 1, 2], type=int,
        help="increase the verbosity level: 0 silent, [1, 2] verbose.")

    # The verbosity is returned apart from the other parsed arguments
    namespace = vars(parser.parse_args())
    verbosity = namespace.pop("verbose")

    return namespace, verbosity


"""
Parse the command line and gather the runtime information.
"""
inputs, verbose = get_cmd_line_args()
outputs = None
tool = "pyconnectome_metrics_heatmap"
timestamp = datetime.now().isoformat()
tool_version = version
# Pick the runtime description from the variables defined above
params = locals()
runtime = {key: params[key] for key in ("tool", "tool_version", "timestamp")}
if verbose > 0:
    print("[info] Starting Network Analysis: heatmap ...")
    for title, content in (("[info] Runtime:", runtime),
                           ("[info] Inputs:", inputs)):
        print(title)
        pprint(content)


"""
Load the network features of each input file
"""
features = {}
for feature_path in inputs["featurefiles"]:
    with open(feature_path, "rt") as stream:
        # The features are indexed by the name of the directory that
        # contains the file
        subject_id = feature_path.split(os.sep)[-2]
        features[subject_id] = json.load(stream)
if verbose > 0:
    nb_features = len(features)
    print("[info] '{0}' features found..".format(nb_features))


"""
Compute the heatmap and save it in a CSV file
"""
heatmap = []
for sid, metrics in features.items():
    profile, header = metric_profile(metrics)
    # Each row starts with the subject identifier
    profile.insert(0, sid)
    heatmap.append(profile)
# The header returned by the last call gets the matching extra first column
header.insert(0, "subjects")
heatmap_file = os.path.join(inputs["outdir"], "network_heatmap.csv")
with open(heatmap_file, "wt") as open_file:
    writer = csv.writer(open_file, delimiter=",")
    writer.writerow(header)
    writer.writerows(heatmap)


"""
Create a snap of the heatmap
"""
# Every row of 'heatmap' describes one subject ([sid, value, ...]): the column
# names are kept in the separate 'header' list, so no row has to be skipped.
# Skipping the first row dropped one subject and, with a single subject, left
# an empty 1-D array on which the column slicing raises an IndexError.
# Only the leading subject identifier column is removed here.
heatmap_snap = sort_features(
    features=numpy.asarray(heatmap)[:, 1:],
    outdir=inputs["outdir"],
    name="network_heatmap",
    header=header[1:],
    verbose=verbose)


"""
Collect the outputs, then dump the inputs, outputs and runtime information
as JSON files in a 'logs' directory.
"""
logdir = os.path.join(inputs["outdir"], "logs")
if not os.path.isdir(logdir):
    os.mkdir(logdir)
# Pick the generated file locations from the variables defined above
params = locals()
outputs = {key: params[key] for key in ("heatmap_file", "heatmap_snap")}
# One '<name>.json' file per structure to save
to_save = (("inputs", inputs), ("outputs", outputs), ("runtime", runtime))
for name, final_struct in to_save:
    log_file = os.path.join(logdir, "{0}.json".format(name))
    with open(log_file, "wt") as open_file:
        json.dump(final_struct, open_file, indent=4, sort_keys=True,
                  check_circular=True)
if verbose > 1:
    print("[info] Outputs:")
    pprint(outputs)

