#! /usr/bin/env python
# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2016
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

# System modules
from __future__ import print_function
import os
import shutil
import glob
import json
import argparse
import nibabel
from datetime import datetime
from pprint import pprint
import numpy
import textwrap
from argparse import RawTextHelpFormatter


# Bredala module (optional): when the package can be imported, the listed
# functions are registered with it. This step is best-effort: any regular
# error is ignored so that the script still runs without bredala, but
# KeyboardInterrupt/SystemExit are no longer swallowed.
try:
    import bredala
    bredala.USE_PROFILER = False
    bredala.register("pyconnectome.tractography.probabilist",
                     names=["probtrackx2"])
    bredala.register("pyconnectome.utils.regtools",
                     names=["freesurfer_bbregister_t1todif", "flirt"])
    bredala.register("pyconnectome.utils.segtools",
                     names=["white_matter_interface"])
except Exception:
    pass

# PyFreeSurfer import
from pyfreesurfer import DEFAULT_FREESURFER_PATH
from pyfreesurfer.wrapper import FSWrapper

# Package import
from pyconnectome import __version__ as version
from pyconnectome import DEFAULT_FSL_PATH
from pyconnectome.wrapper import FSLWrapper
from pyconnectome.tractography.probabilist import probtrackx2
from pyconnectome.utils.regtools import freesurfer_bbregister_t1todif
from pyconnectome.utils.regtools import flirt
from pyconnectome.utils.segtools import white_matter_interface


# Parameters to keep trace
__hopla__ = ["runtime", "inputs", "outputs"]


# Command parameters
doc = """
Perform FSL probabilistic tractography in a single subject using a seeding
point in voxel coordinates.

Example of command for HCP, single seed point:

python $HOME/git/pyconnectome/pyconnectome/scripts/pyconnectome_probtrackx2_tractogram \
    -o /tmp/nsap/tractograms/fsl \
    -b /neurospin/hcp/PROCESSED/3T_bedpostx/102008/T1w/Diffusion.bedpostX \
    -s 71 120 59 \
    -i 1 \
    -A 1000 \
    -L 0.5 \
    -v 2

Example of command for HCP, use a user specified mask to seed:

python $HOME/git/pyconnectome/pyconnectome/scripts/pyconnectome_probtrackx2_tractogram \
    -o /tmp/nsap/tractograms/fsl \
    -b /neurospin/hcp/PROCESSED/3T_bedpostx/102008/T1w/Diffusion.bedpostX \
    -P /neurospin/hcp/ANALYSIS/3T_freesurfer/102008/T1w/102008/mri/aparc+aseg.mgz \
    -S 102008 \
    -D /neurospin/hcp/ANALYSIS/3T_freesurfer/102008/T1w \
    -B /neurospin/hcp/ANALYSIS/3T_mrtrix/102008/nodif_brain.nii.gz \
    -A 1000 \
    -K 1010 \
    -L 0.5 \
    -v 2

Example of command for HCP, use a user specified mask to seed:

python $HOME/git/pyconnectome/pyconnectome/scripts/pyconnectome_probtrackx2_tractogram \
    -o /tmp/nsap/tractograms/fsl/nomask \
    -b /neurospin/hcp/PROCESSED/3T_bedpostx/102008/T1w/Diffusion.bedpostX \
    -a /neurospin/hcp/ANALYSIS/3T_freesurfer/102008/T1w/T1w_acpc_dc_restore_brain.nii.gz \
    -S 102008 \
    -D /neurospin/hcp/ANALYSIS/3T_freesurfer/102008/T1w \
    -B /neurospin/hcp/ANALYSIS/3T_mrtrix/102008/nodif_brain.nii.gz \
    -A 1000 \
    -L 0.5 \
    -v 2
"""


def is_file(filearg):
    """ Type for argparse - checks that file exists but does not open.

    Parameters
    ----------
    filearg: str
        the path received from the command line.

    Returns
    -------
    filearg: str
        the input path, unchanged.

    Raises
    ------
    argparse.ArgumentTypeError
        if the path does not point to an existing file.
    """
    if not os.path.isfile(filearg):
        # ArgumentTypeError is the exception argparse expects from a 'type'
        # callable. ArgumentError requires the offending argument object as
        # first parameter and cannot be built from a message only.
        raise argparse.ArgumentTypeError(
            "The file '{0}' does not exist!".format(filearg))
    return filearg


def is_directory(dirarg):
    """ Type for argparse - checks that directory exists.

    Parameters
    ----------
    dirarg: str
        the path received from the command line.

    Returns
    -------
    dirarg: str
        the input path, unchanged.

    Raises
    ------
    argparse.ArgumentTypeError
        if the path does not point to an existing directory.
    """
    if not os.path.isdir(dirarg):
        # ArgumentTypeError is the exception argparse expects from a 'type'
        # callable. ArgumentError requires the offending argument object as
        # first parameter and cannot be built from a message only.
        raise argparse.ArgumentTypeError(
            "The directory '{0}' does not exist!".format(dirarg))
    return dirarg


def get_cmd_line_args():
    """
    Create a command line argument parser and parse the command line.

    When no FreeSurfer or FSL configuration script is given on the command
    line, the 'DEFAULT_FREESURFER_PATH' and 'DEFAULT_FSL_PATH' values are
    used instead.

    Returns
    -------
    kwargs: dict
        mapping <argument name> -> <argument value>, without the 'verbose'
        entry.
    verbose: int
        the requested verbosity level (0, 1 or 2).
    """
    parser = argparse.ArgumentParser(
        prog="pyconnectome_probtrackx2_tractogram",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=textwrap.dedent(doc))

    # Required arguments
    required = parser.add_argument_group("required arguments")
    required.add_argument(
        "-o", "--outdir",
        type=is_directory, required=True, metavar="<path>",
        help="Directory where to output.")
    required.add_argument(
        "-b", "--bedpostxdir",
        type=is_directory, required=True, metavar="<path>",
        help="The FSL bedpostx directory.")

    # Seeding arguments: none of them is mandatory on its own, the seeding
    # strategy is selected by the script from the ones that are provided.
    required.add_argument(
        "-s", "--point",
        type=int, metavar="<seed>", nargs=3,
        help="A seeding point in index coordinates.")
    required.add_argument(
        "-i", "--index",
        type=int, metavar="<index>",
        help="An index that will be associated to the current seeding point.")
    required.add_argument(
        "-a", "--t1-brain",
        type=is_file, metavar="<path>",
        help="The anatomical image used to generate the seeding mask: use the "
             "FreeSurfer file by default, or if different perform "
             "probabilistic mask with MRtrix.")

    # Optional arguments
    parser.add_argument(
        "-P", "--t1-parc",
        type=is_file, metavar="<path>",
        help="Path to the parcellation that defines the nodes of the connec"
             "tome, e.g. aparc+aseg.mgz from FreeSurfer. It has to be in the "
             "FreeSurfer space (i.e. aligned with FreeSurfer 'brain.mgz').")
    parser.add_argument(
        "-S", "--subject-id",
        metavar="<id>",
        help="Subject identifier.")
    parser.add_argument(
        "-D", "--subjects-dir", metavar="<path>",
        help="FreeSurfer subjects directory. To set or bypass the "
             "$SUBJECTS_DIR environment variable.")
    parser.add_argument(
        "-B", "--nodif-brain",
        type=is_file, metavar="<path>",
        help="Diffusion brain-only volume to use to estimate "
             "the registration between diffusion and T1.")
    parser.add_argument(
        "-K", "--keep-labels",
        type=int, nargs="+", metavar="<labels>",
        help="The labels to be extracted in the parcellation image, default "
             "is all values greater than zero.")
    parser.add_argument(
        "-A", "--nsamples",
        type=int, metavar="<int>", default=5000,
        help="The number of samples in probtrackx.")
    parser.add_argument(
        "-T", "--nsteps",
        type=int, metavar="<int>", default=2000,
        help="The number of steps per sample in probtrackx.")
    parser.add_argument(
        "-L", "--steplength",
        type=float, metavar="<float>", default=0.5,
        help="The propagation step in probtrackx.")
    parser.add_argument(
        "-M", "--sampvox",
        type=float, metavar="<float>", default=0.0,
        help="Random sampling sphere in probtrackx (in mm).")
    parser.add_argument(
        "-H", "--fs-sh",
        type=is_file, metavar="<path>",
        help="Bash script initializing FreeSurfer's environment.")
    parser.add_argument(
        "-F", "--fsl-sh",
        type=is_file, metavar="<path>",
        help="Bash script initializing FSL's environment.")
    parser.add_argument(
        "-v", "--verbose",
        dest="verbose", type=int, choices=[0, 1, 2], default=0,
        help="Increase the verbosity level: 0 silent, [1, 2] verbose.")

    # Create a dict of arguments to pass to the 'main' function (the command
    # line is parsed only once)
    args = parser.parse_args()
    kwargs = vars(args)
    verbose = kwargs.pop("verbose")
    if kwargs["fs_sh"] is None:
        kwargs["fs_sh"] = DEFAULT_FREESURFER_PATH
    if kwargs["fsl_sh"] is None:
        kwargs["fsl_sh"] = DEFAULT_FSL_PATH

    return kwargs, verbose


"""
Parse the command line and gather the runtime information.
"""
inputs, verbose = get_cmd_line_args()
outputs = None
tool = "pyconnectome_probtrackx2_tractogram"
timestamp = datetime.now().isoformat()
tool_version = version
fsl_version = FSLWrapper([], shfile=inputs["fsl_sh"]).version
freesurfer_version = FSWrapper([], inputs["fs_sh"]).version
params = locals()

# Runtime description that is saved with the logs
runtime = {
    "tool": tool,
    "tool_version": tool_version,
    "fsl_version": fsl_version,
    "timestamp": timestamp,
    "freesurfer_version": freesurfer_version}

# Display the execution context on demand
if verbose >= 1:
    print("[info] Starting FSL probtrackx2 tractogram...")
    for title, struct in (("[info] Runtime:", runtime),
                          ("[info] Inputs:", inputs)):
        print(title)
        pprint(struct)


"""
Check the FSL bedpostx directory: it must contain at least one 'merged*'
file and the 'nodif_brain_mask.nii.gz' file.
"""
bedpostx_dir = inputs["bedpostxdir"]
merged_prefix = os.path.join(bedpostx_dir, "merged")
merged_files = glob.glob(merged_prefix + "*")
nodifmask_file = os.path.join(bedpostx_dir, "nodif_brain_mask.nii.gz")
if not merged_files or not os.path.isfile(nodifmask_file):
    raise ValueError(
        "'{0}' is not a valid FSL bedpostx folder.".format(bedpostx_dir))

"""
Generate a mask in diffusion space with a rigid deformation.

Three mutually exclusive cases are handled, in this order:
1- a seeding point is given: nothing to generate.
2- no parcellation is given: the anatomical image is used.
3- a parcellation is given: it is resampled in the diffusion space.
"""
# Point case
if inputs["point"] is not None:

    # Value forwarded as is to the 'opd' option of probtrackx2
    opd = True

# Generate/align the seeding mask from the anatomical image
elif inputs["t1_parc"] is None:

    # Value forwarded as is to the 'opd' option of probtrackx2
    opd = False

    # Check input
    if inputs["t1_brain"] is None:
        raise ValueError(
            "Please specify a parcellation or an anatomical image.")

    # Generate a temporary folder
    tempdir = os.path.join(inputs["outdir"], "tmp")
    if not os.path.isdir(tempdir):
        os.mkdir(tempdir)
    else:
        print("[info] '{0}' folder already created.".format(tempdir))

    # Files produced by the alignment of the anatomical image and of the
    # mask with the diffusion space.
    anat2dif_file = os.path.join(inputs["outdir"], "anat2dif.txt")
    t1brain_to_dif_file = os.path.join(
        inputs["outdir"], "t1_brain_to_dif.nii.gz")
    gmwmi_mask_to_dif_file = os.path.join(
        inputs["outdir"], "gmwmi_mask_to_dif.nii.gz")

    # The computation used to be disabled by 'if 0:' debug guards, so that
    # the script only worked when the results of a previous run were already
    # in the output directory. The results are still reused when available,
    # and are now computed otherwise.
    if not (os.path.isfile(anat2dif_file) and
            os.path.isfile(gmwmi_mask_to_dif_file)):
        gmwmi_mask_file = white_matter_interface(
            t1_brain_file=inputs["t1_brain"],
            outdir=inputs["outdir"],
            tempdir=tempdir,
            fsl_sh=inputs["fsl_sh"])

        # Estimate the rigid anatomical to diffusion transformation...
        flirt(
            in_file=inputs["t1_brain"],
            ref_file=inputs["nodif_brain"],
            omat=anat2dif_file,
            out=t1brain_to_dif_file,
            cost="normmi",
            interp="trilinear",
            dof=6,
            shfile=inputs["fsl_sh"])

        # ... and apply it to the mask
        flirt(
            in_file=gmwmi_mask_file,
            ref_file=inputs["nodif_brain"],
            out=gmwmi_mask_to_dif_file,
            init=anat2dif_file,
            applyxfm=True,
            interp="spline",
            shfile=inputs["fsl_sh"])
    elif verbose > 0:
        print("[info] Reusing '{0}' and '{1}'.".format(
            anat2dif_file, gmwmi_mask_to_dif_file))

    # Invert anat2dif transform
    dif2anat_file = os.path.join(inputs["outdir"], "dif2anat.txt")
    m = numpy.loadtxt(anat2dif_file)
    m_inv = numpy.linalg.inv(m)
    numpy.savetxt(dif2anat_file, m_inv)
    invxfm = dif2anat_file

# Parcellation case. The former test, 'elif["t1_parc"] is not None', was
# evaluated on a list literal and hence always true: the previous branch
# already handles the missing parcellation, a plain 'else' is what remains.
else:

    # Value forwarded as is to the 'opd' option of probtrackx2
    opd = False

    # Align label in T1 space to diffusion space.
    t1_brain_to_dif, dif2anat_dat, _ = freesurfer_bbregister_t1todif(
        outdir=inputs["outdir"],
        subject_id=inputs["subject_id"],
        nodif_brain=inputs["nodif_brain"],
        subjects_dir=inputs["subjects_dir"],
        fs_sh=inputs["fs_sh"],
        fsl_sh=inputs["fsl_sh"])
    parc_name = (
        os.path.basename(inputs["t1_parc"]).split(".nii")[0].split(".mgz")[0])
    t1_parc_to_dif = os.path.join(
        inputs["outdir"], parc_name + "_to_dif.nii.gz")
    cmd = ["mri_vol2vol",
           "--mov",  inputs["nodif_brain"],
           "--targ", inputs["t1_parc"],
           "--inv",
           "--interp", "nearest",
           "--o",   t1_parc_to_dif,
           "--reg", dif2anat_dat,
           "--no-save-reg"]
    FSWrapper(cmd, shfile=inputs["fs_sh"])()


"""
Select the seeding strategy:
1- write single seeding point coordinates to file (-s -i options).
2- write seeding mask point coordinates to file (-P -K options).
3- write generated white interface point coordinates to file (-a option).

Each seed file is named 'fdt_coordinates.txt' and is written in its own
working directory, created in the output directory.
"""
seed_files = []
if inputs["point"] is not None:
    wdir = os.path.join(inputs["outdir"], "{0}".format(inputs["index"]))
    if not os.path.isdir(wdir):
        os.mkdir(wdir)
    else:
        print("[info] '{0}' folder already created.".format(wdir))
    seed_files.append(os.path.join(wdir, "fdt_coordinates.txt"))
    numpy.savetxt(seed_files[-1], inputs["point"])

else:

    # Get coordinates
    # NOTE: 'dataobj' is used to access the image array since the
    # 'get_data'/'get_affine' accessors are deprecated in nibabel and removed
    # from its recent versions.
    nb_seeds = 0
    if inputs["t1_parc"] is not None:
        chunk_size = 20000
        parc_im = nibabel.load(t1_parc_to_dif)
        parc_data = numpy.asanyarray(parc_im.dataobj)
        if inputs["keep_labels"] is None:
            seeds = numpy.argwhere(parc_data > 0)
        else:
            seeds_list = [numpy.argwhere(parc_data == label)
                          for label in inputs["keep_labels"]]
            seeds = numpy.concatenate(seeds_list, axis=0)
        nb_seeds = len(seeds)

        # Split the seeds in the smallest number of chunks containing at
        # most 'chunk_size' seeds each (ceiling of nb_seeds / chunk_size),
        # with one label per chunk.
        nb_chunks = (nb_seeds + chunk_size - 1) // chunk_size
        if nb_chunks > 0:
            seeds = numpy.array_split(seeds, nb_chunks)
        else:
            seeds = []
        labels = ["{0}_{1}".format(idx + 1, chunk_size)
                  for idx in range(nb_chunks)]
    else:
        gmwmi_mask_im = nibabel.load(gmwmi_mask_to_dif_file)
        gmwmi_mask_data = numpy.asanyarray(gmwmi_mask_im.dataobj)
        lower_thr = numpy.linspace(0, 0.95, 20)
        upper_thr = numpy.linspace(0.05, 1, 20)
        seeds = []
        labels = []
        for idx, (lower, upper) in enumerate(zip(lower_thr, upper_thr)):

            # Do not consider p<0.1, i.e. the two first intervals. They are
            # skipped from their position: slicing the list of non-empty
            # intervals afterwards would drop valid intervals when one of
            # the two first intervals is empty.
            if idx < 2:
                continue
            part_seeds = numpy.argwhere(
                (gmwmi_mask_data > lower) & (gmwmi_mask_data <= upper))
            if len(part_seeds) != 0:
                seeds.append(part_seeds)
                nb_seeds += len(part_seeds)
                labels.append("{0}_{1}".format(lower, upper))

    if verbose > 1:
        print("[info] Number of seeds: {0}.".format(nb_seeds))

    # Write coordinates
    # NOTE(review): the seeds are written as array indices of the image they
    # are extracted from. An image-to-diffusion transformation used to be
    # computed here but was never applied - presumably both images share the
    # grid of the tractography reference volume, to be confirmed.
    for seeds_part, label in zip(seeds, labels):
        wdir = os.path.join(inputs["outdir"], label)
        if not os.path.isdir(wdir):
            os.mkdir(wdir)
        seed_files.append(os.path.join(wdir, "fdt_coordinates.txt"))
        numpy.savetxt(seed_files[-1], seeds_part)

"""
Start the tractogram computation.

For each seed file, the tractography is executed in the directory of the
seed file. The outputs are then updated and saved with the inputs and the
runtime in a 'logs' directory created in this same directory, where the
'probtrackx.log' file is also moved.
"""
if len(seed_files) == 0:
    raise ValueError("No seed to track from has been found.")
for seed_file in seed_files:

    # Working directory associated to the current seed file
    wdir = os.path.dirname(seed_file)

    # Tractography
    proba_files, network_file = probtrackx2(
        samples=merged_prefix,
        mask=nodifmask_file,
        seed=seed_file,
        nsamples=inputs["nsamples"],
        nsteps=inputs["nsteps"],
        steplength=inputs["steplength"],
        sampvox=inputs["sampvox"],
        simple=True,
        loopcheck=True,
        dir=wdir,
        out="fdt_paths",
        seedref=nodifmask_file,
        onewaycondition=True,
        opd=opd,
        forcedir=True,
        savepaths=True,
        shfile=inputs["fsl_sh"])

    # Logs: performed for every seed file, otherwise only the results of
    # the last tractography are kept trace of.
    logdir = os.path.join(wdir, "logs")
    if not os.path.isdir(logdir):
        os.mkdir(logdir)
    shutil.move(os.path.join(wdir, "probtrackx.log"),
                os.path.join(logdir, "probtrackx.log"))
    outputs = {"proba_files": proba_files, "network_file": network_file}
    for name, final_struct in [("inputs", inputs), ("outputs", outputs),
                               ("runtime", runtime)]:
        log_file = os.path.join(logdir, "{0}.json".format(name))
        with open(log_file, "wt") as open_file:
            json.dump(final_struct, open_file, sort_keys=True,
                      check_circular=True, indent=4)
    if verbose > 1:
        print("[info] Outputs:")
        pprint(outputs)

