#! /usr/bin/env python
# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2016
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html for details.
##########################################################################

# System import
import os
import argparse
import json
from pprint import pprint
from datetime import datetime
import textwrap
from argparse import RawTextHelpFormatter

# Bredala module
# Optional instrumentation: when the 'bredala' package is available, register
# 'mrtrix_tractogram' with it (profiler disabled). Any failure here is
# deliberately ignored so that the script still runs without bredala.
try:
    import bredala
    bredala.USE_PROFILER = False
    bredala.register("pyconnectome.tractography.probabilist",
                     names=["mrtrix_tractogram"])
except Exception:
    # 'Exception' rather than a bare 'except:' so that KeyboardInterrupt and
    # SystemExit are not silently swallowed.
    pass

# PyFreeSurfer import
from pyfreesurfer import DEFAULT_FREESURFER_PATH
from pyfreesurfer.wrapper import FSWrapper

# Package import
from pyconnectome import __version__ as version
from pyconnectome import DEFAULT_FSL_PATH
from pyconnectome.wrapper import FSLWrapper
from pyconnectome.tractography.probabilist import mrtrix_tractogram


# Names of the module-level variables to keep a trace of. All three are
# assigned further down in this script.
# NOTE(review): presumably read by the 'hopla' runner to collect these values
# once the script has been executed - confirm.
__hopla__ = ["runtime", "inputs", "outputs"]


# Description displayed by the command line help. This is a raw string so
# that the trailing backslashes of the example commands are kept literally:
# in a regular string, a backslash followed by a newline is removed and each
# example would be displayed as a single, very long line.
DOC = r"""
Compute the tractogram of a diffusion dataset using MRtrix.

Requirements:
    - preprocessed DWI with bvals/bvecs: if distortion artifacts from
      acquisition have been properly corrected.
    - FreeSurfer recon-all or a T1 aligned to diffusion. T1 can be kept in
      its native resolution as long as it is registered to diffusion.


Example of command for HCP, using SIFT2:

python $HOME/git/pyconnectome/pyconnectome/scripts/pyconnectome_mrtrix_tractogram \
    -o /tmp/nsap/tractograms/mrtrix \
    -d /tmp/nsap/tractograms/mrtrix/tmp \
    -s 889579 \
    -i /neurospin/hcp/ANALYSIS/3T_freesurfer/889579/T1w/Diffusion/data.nii.gz \
    -b /neurospin/hcp/ANALYSIS/3T_freesurfer/889579/T1w/Diffusion/bvals \
    -r /neurospin/hcp/ANALYSIS/3T_freesurfer/889579/T1w/Diffusion/bvecs \
    -n 5 \
    -T 1 \
    -L 250 \
    -C 0.06 \
    -I \
    -Z \
    -M /neurospin/hcp/ANALYSIS/3T_freesurfer/889579/T1w/Diffusion/nodif_brain_mask.nii.gz \
    -A /neurospin/hcp/ANALYSIS/3T_freesurfer/889579/T1w/T1w_acpc_dc_restore.nii.gz \
    -S /neurospin/hcp/ANALYSIS/3T_freesurfer/889579/T1w \
    -v 2

Example of command for HCP, using global tractography (experimental):

python $HOME/git/pyconnectome/pyconnectome/scripts/pyconnectome_mrtrix_tractogram \
    -o /tmp/nsap/tractograms/mrtrix \
    -d /tmp/nsap/tractograms/mrtrix/tmp \
    -s 889579 \
    -i /neurospin/hcp/ANALYSIS/3T_freesurfer/889579/T1w/Diffusion/data.nii.gz \
    -b /neurospin/hcp/ANALYSIS/3T_freesurfer/889579/T1w/Diffusion/bvals \
    -r /neurospin/hcp/ANALYSIS/3T_freesurfer/889579/T1w/Diffusion/bvecs \
    -G \
    -M /neurospin/hcp/ANALYSIS/3T_freesurfer/889579/T1w/Diffusion/nodif_brain_mask.nii.gz \
    -A /neurospin/hcp/ANALYSIS/3T_freesurfer/889579/T1w/T1w_acpc_dc_restore.nii.gz \
    -S /neurospin/hcp/ANALYSIS/3T_freesurfer/889579/T1w \
    -v 2
"""


def is_file(filepath):
    """ Check file's existence - argparse 'type' argument.

    Parameters
    ----------
    filepath: str
        the path received from the command line.

    Returns
    -------
    filepath: str
        the input path, unchanged, when it points to an existing file.

    Raises
    ------
    argparse.ArgumentTypeError
        if 'filepath' is not an existing file.
    """
    if not os.path.isfile(filepath):
        # ArgumentTypeError is the exception argparse expects from a 'type'
        # callable: its message is reported as is to the user.
        # ArgumentError is not suited here: its constructor requires the
        # argument object as first parameter, so calling it with the message
        # only fails with a TypeError and the message is lost.
        raise argparse.ArgumentTypeError(
            "File does not exist: %s" % filepath)
    return filepath


def get_cmd_line_args():
    """
    Build the command line parser, parse the command line and prepare the
    arguments of the tractogram computation.

    Returns
    -------
    cli_kwargs: dict
        mapping <argument name> -> <argument value> for all the options
        except 'verbose'. When not given, 'fs_sh' and 'fsl_sh' are set to
        DEFAULT_FREESURFER_PATH and DEFAULT_FSL_PATH respectively.
    verbosity: int
        the requested verbosity level (0, 1 or 2).
    """
    parser = argparse.ArgumentParser(
        prog="pyconnectome_mrtrix_tractogram",
        formatter_class=RawTextHelpFormatter,
        description=textwrap.dedent(DOC))

    # Required arguments: listed in a dedicated group of the help message
    required_group = parser.add_argument_group("required arguments")
    add_required = required_group.add_argument
    add_required("-o", "--outdir", required=True, metavar="<path>",
                 help="Directory where to output.")
    add_required("-d", "--tempdir", required=True, metavar="<path>",
                 help="Where to write temporary directories e.g. /tmp.")
    add_required("-s", "--subject-id", required=True, metavar="<id>",
                 help="Subject identifier.")
    add_required("-i", "--dwi", required=True, type=is_file,
                 metavar="<path>",
                 help="Path to the diffusion data.")
    add_required("-b", "--bvals", required=True, type=is_file,
                 metavar="<path>",
                 help="Path to the bvalue list.")
    add_required("-r", "--bvecs", required=True, type=is_file,
                 metavar="<path>",
                 help="Path to the list of diffusion-sensitized directions.")
    add_required("-n", "--nb-threads", required=True, type=int,
                 metavar="<int>",
                 help="Number of threads.")

    # Optional arguments: tractography settings
    add_optional = parser.add_argument
    add_optional("-G", "--global-tractography", action="store_true",
                 help="If set run global tractography (tckglobal) instead of local "
                      "(tckgen).")
    add_optional("-T", "--mtracks", type=int, metavar="<int>",
                 help="For non-global tractography only. "
                      "Number of millions of tracks in raw tractogram.")
    add_optional("-L", "--maxlength", type=int, metavar="<int>",
                 help="For non-global tractography only. Max fiber length in mm.")
    add_optional("-C", "--cutoff", type=float, metavar="<float>",
                 help="For non-global tractography only. FOD cutoff=stopping criteria.")
    add_optional("-I", "--seed-gmwmi", action="store_true",
                 help="Set this option if you want to activate the "
                      "'-seed_gmwmi' option of MRtrix 'tckgen', to "
                      "seed from the GM/WM interface. Otherwise, and "
                      "by default, the seeding is done in white matter "
                      "('-seed_dynamic' option).")
    add_optional("-R", "--sift-mtracks", type=int, metavar="<int>",
                 help="Number of millions of tracks to keep with SIFT. "
                      "If not set, SIFT is not applied.")
    add_optional("-Z", "--sift2", action="store_true",
                 help="To activate SIFT2.")

    # Optional arguments: anatomical inputs
    add_optional("-B", "--nodif-brain", type=is_file, metavar="<path>",
                 help="Diffusion brain-only Nifti volume with bvalue ~ 0. If not "
                      "passed, it is generated automatically by averaging all the b0 "
                      "volumes of the DWI.")
    add_optional("-M", "--nodif-brain-mask", type=is_file, metavar="<path>",
                 help="Path to the Nifti brain binary mask in diffusion. If not "
                      "passed, it is created with MRtrix 'dwi2mask'.")
    add_optional("-A", "--fast-t1-brain", type=is_file, metavar="<path>",
                 help="By default FSL FAST is run on the FreeSurfer 'brain.mgz'. If "
                      "you want the WM probability map to be computed from another T1, "
                      "pass the T1 brain-only volume. Note that it has to be aligned "
                      "with diffusion. This argument is useful for HCP, where some "
                      "FreeSurfer 'brain.mgz' cannot be processed by FSL FAST.")
    add_optional("-S", "--subjects-dir", metavar="<path>",
                 help="FreeSurfer subjects directory. To set or bypass the "
                      "$SUBJECTS_DIR environment variable.")

    # Optional arguments: disk usage. '-U' and '-K' are inverted flags, they
    # switch off a destination that is True by default.
    add_optional("-U", "--no-mif-gz", dest="mif_gz", action="store_false",
                 default=True,
                 help="To not compress MIF files: .mif instead of .mif.gz.")
    add_optional("-D", "--delete-raw-tracks", action="store_true",
                 help="To save disk space, delete the raw tracks when "
                      "the connectome has been computed.")
    add_optional("-K", "--keep-dwi-mif", dest="delete_dwi_mif",
                 action="store_false", default=True,
                 help="To not delete <outdir>/DWI.mif, which is a copy "
                      "of the input <dwi> in the .mif format.")

    # Optional arguments: softwares configuration and verbosity
    add_optional("-H", "--fs-sh", type=is_file, metavar="<path>",
                 help="Bash script initializing FreeSurfer's environment.")
    add_optional("-F", "--fsl-sh", type=is_file, metavar="<path>",
                 help="Bash script initializing FSL's environment.")
    add_optional("-v", "--verbose", type=int, choices=[0, 1, 2], default=2,
                 help="Increase the verbosity level: 0 silent, [1, 2] verbose.")

    # Turn the parsed namespace into the keyword arguments of the processing
    # function: the verbosity is returned separately.
    cli_kwargs = vars(parser.parse_args())
    verbosity = cli_kwargs.pop("verbose")

    # Fall back to the default configuration scripts when none is given
    for key, fallback in (("fs_sh", DEFAULT_FREESURFER_PATH),
                          ("fsl_sh", DEFAULT_FSL_PATH)):
        if cli_kwargs[key] is None:
            cli_kwargs[key] = fallback

    return cli_kwargs, verbosity


"""
Parse the command line.
"""
# 'inputs' holds the keyword arguments later passed to 'mrtrix_tractogram'
inputs, verbose = get_cmd_line_args()

# Description of this run: tool name and version, versions of the wrapped
# softwares and starting time.
tool = "pyconnectome_mrtrix_tractogram"
timestamp = datetime.now().isoformat()
tool_version = version
fsl_version = FSLWrapper([], shfile=inputs["fsl_sh"]).version
freesurfer_version = FSWrapper([], inputs["fs_sh"]).version
params = locals()
runtime = {
    key: params[key]
    for key in ("tool", "tool_version", "fsl_version", "freesurfer_version",
                "timestamp")}

# Not known yet: filled once the tractogram has been computed
outputs = None

if verbose >= 1:
    pprint("[info] Starting MRTrix3 tractogram ...")
    pprint("[info] Runtime:")
    pprint(runtime)
    pprint("[info] Inputs:")
    pprint(inputs)


"""
Start the tractogram computation.
"""
# All the parsed command line options are forwarded as keyword arguments
tracks, sift_tracks, sift2_weights = mrtrix_tractogram(**inputs)


"""
Update the outputs and save them and the inputs in a 'logs' directory.
"""
logdir = os.path.join(inputs["outdir"], "logs")
if not os.path.isdir(logdir):
    os.mkdir(logdir)
params = locals()
outputs = {
    key: params[key] for key in ("tracks", "sift_tracks", "sift2_weights")}

# One JSON file per structure: inputs.json, outputs.json and runtime.json
for name, final_struct in (("inputs", inputs),
                           ("outputs", outputs),
                           ("runtime", runtime)):
    log_file = os.path.join(logdir, "{0}.json".format(name))
    with open(log_file, "wt") as open_file:
        json.dump(final_struct, open_file, indent=4, sort_keys=True,
                  check_circular=True)

if verbose >= 2:
    pprint("[info] Outputs:")
    pprint(outputs)
