#! /usr/bin/env python
##########################################################################
# NSAp - Copyright (C) CEA, 2013 - 2017
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

# System import
from __future__ import print_function
import argparse
import os
import shutil
from datetime import datetime
import json
from pprint import pprint
import textwrap
from argparse import RawTextHelpFormatter

# Bredala import
# bredala is optional: when it cannot be imported or configured the script
# simply runs without it.
try:
    import bredala
    bredala.USE_PROFILER = False
    bredala.register("pyconnectome.utils.segtools",
                     names=["bet2", "fast", "robustfov", "roi_from_bbox"])
    bredala.register("pyconnectome.utils.filetools",
                     names=["fslreorient2std"])
    bredala.register("pyconnectome.utils.regtools",
                     names=["flirt", "fnirt"])
    bredala.register("pydcmio.plotting.slicer",
                     names=["mosaic"])
except Exception:
    # Deliberate best effort. 'Exception' (rather than a bare except) lets
    # KeyboardInterrupt and SystemExit propagate.
    pass

# Package import
from pyconnectome import __version__ as version
from pyconnectome import DEFAULT_FSL_PATH
from pyconnectome.wrapper import FSLWrapper
from pyconnectome.utils.segtools import bet2
from pyconnectome.utils.segtools import fast
from pyconnectome.utils.segtools import robustfov
from pyconnectome.utils.segtools import roi_from_bbox
from pyconnectome.utils.regtools import flirt
from pyconnectome.utils.regtools import fnirt
from pyconnectome.utils.filetools import fslreorient2std

# Third party import
import numpy
from pydcmio.plotting.slicer import mosaic


# Parameters to keep trace
# NOTE(review): presumably the names of the module level variables collected
# by an external tracing tool - confirm.
__hopla__ = ["runtime", "inputs", "outputs"]


# Script documentation, used as the description of the command line help.
# A raw string is required: in a regular string literal each trailing
# backslash would act as a line continuation and collapse every example
# command onto a single line.
DOC = r"""
FSL registration

Rigid/affine or non linear registration with FSL using flirt and fnirt.

Command:

python $HOME/git/pyconnectome/pyconnectome/scripts/pyconnectome_register \
    -v 2 \
    -s test_fsl \
    -o /volatile/nsap/recalage_cathy/results \
    -i /volatile/nsap/recalage_cathy/T1W_gado.nii.gz \
    -r /usr/share/fsl/data/standard/MNI152_T1_1mm.nii.gz \
    -X /volatile/nsap/recalage_cathy/mask_lesion.nii.gz /volatile/nsap/recalage_cathy/mask_necrosis.nii.gz \
    -F /etc/fsl/5.0/fsl.sh \
    -K 3 \
    -T 1 \
    -H 0.45 \
    -B \
    -A normmi \
    -D


Command example for T1 to atlas affine registration with cropping on
the Metastasis data:

python $HOME/git/pyconnectome/pyconnectome/scripts/pyconnectome_register \
    -o /volatile/nsap/recalage_cathy/results \
    -s 585521174283 \
    -i /neurospin/radiomics/studies/metastasis/base/585521174283/anat/585521174283_enh-gado_T1w.nii.gz \
    -r /usr/share/fsl/data/standard/MNI152_T1_1mm.nii.gz \
    -O \
    -R 0 181 0 217 79 121 \
    -D \
    -Q /volatile/nsap/recalage_cathy/results/585521174283/brain.nii.gz \
    -S \
    -v 2

Command example for T1 to atlas non linear registration with cropping on
the Metastasis data:

python $HOME/git/pyconnectome/pyconnectome/scripts/pyconnectome_register \
    -o /volatile/nsap/recalage_cathy/results/nl \
    -s 585521174283 \
    -i /neurospin/radiomics/studies/metastasis/base/585521174283/anat/585521174283_enh-gado_T1w.nii.gz \
    -r /usr/share/fsl/data/standard/MNI152_T1_1mm.nii.gz \
    -O \
    -N \
    -J /usr/share/fsl/data/standard/MNI152_T1_1mm_brain.nii.gz \
    -R 0 181 0 217 79 121 \
    -D \
    -Q /volatile/nsap/recalage_cathy/results/585521174283/brain.nii.gz \
    -S \
    -v 2
"""


def is_file(filearg):
    """ Type for argparse - checks that file exists but does not open.

    Parameters
    ----------
    filearg: str
        the path received from the command line.

    Returns
    -------
    filearg: str
        the input path, unchanged.

    Raises
    ------
    argparse.ArgumentTypeError
        if the path is not an existing file.
    """
    if not os.path.isfile(filearg):
        # ArgumentTypeError is the exception argparse expects from a 'type'
        # callable: its message is reported to the user as is.
        raise argparse.ArgumentTypeError(
            "The file '{0}' does not exist!".format(filearg))
    return filearg


def is_directory(dirarg):
    """ Type for argparse - checks that directory exists.

    Parameters
    ----------
    dirarg: str
        the path received from the command line.

    Returns
    -------
    dirarg: str
        the input path, unchanged.

    Raises
    ------
    argparse.ArgumentTypeError
        if the path is not an existing directory.
    """
    if not os.path.isdir(dirarg):
        # ArgumentTypeError is the exception argparse expects from a 'type'
        # callable: its message is reported to the user as is.
        raise argparse.ArgumentTypeError(
            "The directory '{0}' does not exist!".format(dirarg))
    return dirarg


def get_cmd_line_args():
    """ Build the command line parser and parse the script arguments.

    Returns
    -------
    options: dict
        mapping <argument name> -> <argument value>, without the 'verbose'
        entry. A missing 'fsl_sh' is replaced by DEFAULT_FSL_PATH.
    verbosity: int
        the requested verbosity level.
    """
    cli = argparse.ArgumentParser(
        prog="pyconnectome_register",
        description=textwrap.dedent(DOC),
        formatter_class=RawTextHelpFormatter)

    # Each entry is (option flags, add_argument keyword arguments); the
    # declaration order is the order displayed in the help message.
    mandatory_specs = [
        (("-o", "--outdir"), dict(
            type=is_directory, metavar="<path>", required=True,
            help="the analysis output directory where <outdir>/<sid> will be "
                 "generated.")),
        (("-i", "--inputfile"), dict(
            type=is_file, required=True,
            help="the input MRI volume to be registered.")),
        (("-r", "--referencefile"), dict(
            type=is_file, required=True,
            help="the template file used as a reference volume during the linear "
                 "registration.")),
    ]
    optional_specs = [
        (("-s", "--sid"), dict(
            help="the subject identifier.")),
        (("-C", "--nosearch"), dict(
            action="store_true",
            help="if set, perform no search to initializa the linear "
                 "optimization.")),
        (("-P", "--dof"), dict(
            type=int, choices=[6, 12], default=12,
            help="the number of DOF in the linear registration.")),
        (("-S", "--dosnap"), dict(
            action="store_true",
            help="if activated, generate QC snaps.")),
        (("-R", "--bbox"), dict(
            type=int, nargs=6, metavar="<INT>",
            help="a bounding box that will be wraped into the reference space.")),
        (("-E", "--erase"), dict(
            action="store_true",
            help="if activated, clean the result folder if already created.")),
        (("-X", "--extrafiles"), dict(
            type=is_file, nargs="*",
            help="list of files to apply transform to.")),
        (("-D", "--dotissues"), dict(
            action="store_true",
            help="if activated, segment the scan tissues with FAST and apply "
                 "spatial intensity variations correction.")),
        (("-T", "--intype"), dict(
            type=int, choices=range(1, 4), default=1,
            help="type of the input image 1=T1, 2=T2, 3=PD")),
        (("-K", "--nbklass"), dict(
            type=int, default=3,
            help="the number of class in the FAST modelization.")),
        (("-Z", "--doreorient"), dict(
            action="store_true",
            help="if set, use the FSL routine to reorient input images.")),
        (("-O", "--docrop"), dict(
            action="store_true",
            help="if set, reduce the FOV of the input image to remove lower head "
                 "and neck.")),
        (("-W", "--cropwho"), dict(
            choices=("moving", "reference", "both"), default="moving",
            help="select which input images will be cropped.")),
        (("-I", "--brainsize"), dict(
            type=int, default=170,
            help="the brain size used during the cropping.")),
        (("-B", "--dobrain"), dict(
            action="store_true",
            help="if set, use the BET2 routine to segment the subject brain.")),
        (("-Q", "--brain"), dict(
            type=is_file,
            help="the subject brain.")),
        (("-H", "--brainthr"), dict(
            dest="brainthr", type=float, default=0.5,
            help="the BET2 fractional intensity threshold (0->1).")),
        (("-A", "--cost"), dict(
            choices=("mutualinfo", "corratio", "normcorr", "normmi",
                     "leastsq", "labeldiff", "bbr"),
            default="normmi",
            help="the affine cost function type.")),
        (("-L", "--interp"), dict(
            choices=("spline", "nearestneighbour", "trilinear"),
            default="spline",
            help="the interpolation used when applying the deformation to extra "
                 "files.")),
        (("-N", "--dononlinear"), dict(
            action="store_true",
            help="if set, use the FNIRT routine to align the subject brain to the "
                 "template with a non linear transformation.")),
        (("-J", "--referencebrainfile"), dict(
            type=is_file,
            help="the template brain file used as a reference volume during the "
                 "non linear registration.")),
        (("-G", "--wmseg"), dict(
            type=is_file,
            help="the white matter segementation file needed when BBR cost is "
                 "used.")),
        (("-F", "--fsl-sh"), dict(
            type=is_file, metavar="<path>",
            help="bash script initializing FSL's environment.")),
        (("-v", "--verbose"), dict(
            type=int, choices=[0, 1, 2], default=0,
            help="increase the verbosity level: 0 silent, [1, 2] verbose.")),
    ]

    # Required arguments are displayed in their own group
    mandatory = cli.add_argument_group("required arguments")
    for flags, spec in mandatory_specs:
        mandatory.add_argument(*flags, **spec)

    # Optional arguments
    for flags, spec in optional_specs:
        cli.add_argument(*flags, **spec)

    # The verbosity is returned separately from the other options
    options = vars(cli.parse_args())
    verbosity = options.pop("verbose")
    if options["fsl_sh"] is None:
        options["fsl_sh"] = DEFAULT_FSL_PATH

    return options, verbosity


"""
Parse the command line.
"""
inputs, verbose = get_cmd_line_args()
tool = "pyconnectome_register"
timestamp = datetime.now().isoformat()
tool_version = version
fsl_version = FSLWrapper([], shfile=inputs["fsl_sh"]).version
params = locals()
runtime = dict([(name, params[name])
               for name in ("tool", "tool_version", "fsl_version",
                            "timestamp")])
if inputs["sid"] is not None:
    subjdir = os.path.join(inputs["outdir"], inputs["sid"])
else:
    subjdir = inputs["outdir"]
if inputs["erase"] and os.path.isdir(subjdir):
    shutil.rmtree(subjdir)
if not os.path.isdir(subjdir):
    os.mkdir(subjdir)
outputs = None
snaps = []
if inputs["dosnap"]:
    snap_dir = os.path.join(subjdir, "snap")
    if not os.path.isdir(snap_dir):
        os.mkdir(snap_dir)
if verbose > 0:
    pprint("[info] Starting registration ...")
    pprint("[info] Runtime:")
    pprint(runtime)
    pprint("[info] Inputs:")
    pprint(inputs)


"""
Step -1: Reorient images.
"""
if inputs["doreorient"]:
    reorient_dir = os.path.join(subjdir, "reorient")
    if not os.path.isdir(reorient_dir):
        os.mkdir(reorient_dir)
    inputfile = os.path.join(
        reorient_dir, os.path.basename(inputs["inputfile"]))
    inputfile = fslreorient2std(
        inputs["inputfile"], inputfile, fslconfig=inputs["fsl_sh"])
    referencefile = os.path.join(
        reorient_dir, os.path.basename(inputs["referencefile"]))
    referencefile = fslreorient2std(
        inputs["referencefile"], referencefile, fslconfig=inputs["fsl_sh"])
else:
    inputfile = inputs["inputfile"]
    referencefile = inputs["referencefile"]
if verbose > 0:
    print("[result] Input image: {0}.".format(inputfile))
    print("[result] Reference image: {0}.".format(referencefile))


"""
Step 0: Remove lower head and neck.
"""
crop_outputs = [None, None]
crop_outputs_und = [None, None]
crop_outputs_trf = [None, None]
if inputs["docrop"]:
    
    # Select step inputs
    step_inputs = [None, None]
    if inputs["cropwho"] in ("moving", "both"):
        step_inputs[0] = inputfile
    if inputs["cropwho"] in ("reference", "both"):
        step_inputs[1] = referencefile

    # Go through each input
    for cnt, path in enumerate(step_inputs):

        # Stop cas
        if path is None:
            continue

        # Crop
        basename = os.path.basename(path)
        crop_dir = os.path.join(subjdir, "robustfov")
        if not os.path.isdir(crop_dir):
            os.mkdir(crop_dir)
        cropped_file = os.path.join(crop_dir, "robustfov_" + basename)
        cropped_trf = os.path.join(
            crop_dir, "robustfov_" + basename.split(".")[0] + ".txt")
        robustfov(
            input_file=path,
            output_file=cropped_file,
            brain_size=inputs["brainsize"],
            matrix_file=cropped_trf,
            fsl_sh=inputs["fsl_sh"])
        crop_outputs[cnt] = cropped_file
        crop_outputs_trf[cnt] = cropped_trf

        # Restore orginial shape
        cropped_und_file = os.path.join(crop_dir, "robustfov_und_" + basename)
        cropped_und_file, _ = flirt(
            in_file=cropped_file,
            ref_file=path,
            out=cropped_und_file,
            init=cropped_trf,
            applyxfm=True,
            verbose=verbose,
            shfile=inputs["fsl_sh"])
        crop_outputs[cnt] = cropped_und_file

        # Perform QC
        if inputs["dosnap"]:
            snaps.append(mosaic(impath=cropped_file,
                                title="robustfov_{0}".format(cnt),
                                outdir=snap_dir))
if crop_outputs[0] is not None:
    cropped_moving_file = crop_outputs[0]
else:
    cropped_moving_file = inputfile
if crop_outputs[1] is not None:
    cropped_reference_file = crop_outputs[1]
else:
    cropped_reference_file = referencefile
if verbose > 0:
    print("[result] Input cropped image: {0}.".format(cropped_moving_file))
    print("[result] Reference cropped image: {0}.".format(
        cropped_reference_file))


"""
Step 1: Brain extraction
"""
if inputs["brain"] is not None:
    mask_file, mesh_file, outline_file, inskull_mask_file = (None, ) * 4
    nskull_mesh_file, outskull_mesh_file, outskin_mask_file = (None, ) * 3
    skull_mask_file, outskin_mesh_file, outskull_mask_file = (None, ) * 3
    brain_file = inputs["brain"]
    # No BET, use full FOV
    if inputs["docrop"] and inputs["cropwho"] in ("moving", "both"):
        cropped_moving_file = crop_outputs_und[0]
elif inputs["dobrain"]:
    bet_dir = os.path.join(subjdir, "bet")
    if not os.path.isdir(bet_dir):
        os.mkdir(bet_dir)
    (brain_file, mask_file, mesh_file, outline_file, inskull_mask_file,
     inskull_mesh_file, outskull_mask_file, outskull_mesh_file,
     outskin_mask_file, outskin_mesh_file, skull_mask_file) = bet2(
        input_file=cropped_moving_file,
        output_fileroot=bet_dir,
        outline=False,
        mask=True,
        skull=False,
        nooutput=False,
        f=inputs["brainthr"],
        g=0,
        radius=None,
        smooth=None,
        c=None,
        threshold=False,
        mesh=False,
        shfile=inputs["fsl_sh"])
    if inputs["dosnap"]:
        snaps.append(mosaic(impath=brain_file,
                            title="bet",
                            outdir=snap_dir))
else:
    mask_file, mesh_file, outline_file, inskull_mask_file = (None, ) * 4
    nskull_mesh_file, outskull_mesh_file, outskin_mask_file = (None, ) * 3
    skull_mask_file, outskin_mesh_file, outskull_mask_file = (None, ) * 3
    brain_file =  cropped_moving_file
if verbose > 0:
    print("[result] Brain image: {0}.".format(brain_file))
    print("[result] Brain mask image: {0}.".format(mask_file))


"""
Step 2: Tissue segmentation and spatial intensity variations correction.
"""
if inputs["dotissues"]:
    fast_dir = os.path.join(subjdir, "fast")
    if not os.path.isdir(fast_dir):
        os.mkdir(fast_dir)
    fast_fileroot = os.path.join(
        fast_dir, os.path.basename(inputs["inputfile"]).split(".")[0])
    tpm, tsm, segmentation_anatfile, bias_anatfile, biascorrected_anatfile = fast(
        input_file=brain_file,
        out_fileroot=fast_fileroot,
        klass=inputs["nbklass"],
        im_type=inputs["intype"],
        segments=True,
        bias_field=True,
        bias_corrected_im=True,
        probabilities=True,
        shfile=inputs["fsl_sh"])
    if inputs["dosnap"]:
        snaps.append(mosaic(impath=segmentation_anatfile,
                            title="fast_segmentation",
                            outdir=snap_dir))
else:
    tpm, tsm, segmentation_anatfile, bias_anatfile = (None, None, None, None)
    biascorrected_anatfile = brain_file
if verbose > 0:
    print("[result] Antomical tissues: {0}.".format(segmentation_anatfile))
    print("[result] Antomical bias field: {0}.".format(bias_anatfile))
    print("[result] Antomical bias corrected: {0}.".format(
        biascorrected_anatfile))


"""
Step 3: Affine registration
"""
basename = os.path.basename(cropped_moving_file).split(".")[0]
flirt_dir = os.path.join(subjdir, "flirt")
if not os.path.isdir(flirt_dir):
    os.mkdir(flirt_dir)
affine_file, affine_omat = flirt(
    in_file=biascorrected_anatfile,
    ref_file=cropped_reference_file,
    omat=os.path.join(flirt_dir, "flirt_{0}_to_template.txt".format(basename)),
    out=os.path.join(flirt_dir, "flirt_{0}.nii.gz".format(basename)),
    init=None,
    cost=inputs["cost"],
    usesqform=False,
    displayinit=False,
    anglerep="euler",
    bins=256,
    interp="trilinear",
    dof=inputs["dof"],
    applyxfm=False,
    applyisoxfm=None,
    nosearch=inputs["nosearch"],
    wmseg=inputs["wmseg"],
    verbose=verbose,
    shfile=inputs["fsl_sh"])
if inputs["dosnap"]:
    snaps.append(mosaic(impath=cropped_reference_file,
                        overlay=affine_file,
                        title="affine_registration",
                        outdir=snap_dir))
if verbose > 0:
    print("[result] Affine: {0}.".format(affine_file))
    print("[result] Affine transformation: {0}.".format(affine_omat))


"""
Step 4: Non linear registration
"""
if inputs["dononlinear"]:
    fnirt_dir = os.path.join(subjdir, "fnirt")
    if not os.path.isdir(fnirt_dir):
        os.mkdir(fnirt_dir)
    cout, iout, fout, jout, refout, intout, logout = fnirt(
        in_file=biascorrected_anatfile,
        ref_file=cropped_reference_file,
        affine_file=affine_omat,
        outdir=fnirt_dir,
        inmask_file=None,
        verbose=verbose,
        shfile=inputs["fsl_sh"])
    if inputs["dosnap"]:
        snaps.append(mosaic(impath=cropped_reference_file,
                            overlay=iout,
                            title="nonlinear_registration",
                            outdir=snap_dir))
else:
    cout, iout, fout, jout, refout, intout, logout = (None, ) * 7


"""
Step 5: Create bounding box
"""
if inputs["bbox"] is not None:

    # Create bbox
    bbox_dir = os.path.join(subjdir, "bbox")
    if not os.path.isdir(bbox_dir):
        os.mkdir(bbox_dir)
    bbox_file = os.path.join(bbox_dir, "bbox.nii.gz")
    roi_from_bbox(
        input_file=cropped_reference_file,
        bbox=inputs["bbox"],
        output_file=bbox_file)

    # Invert anat2dif transform
    inv_affine_omat = os.path.join(
        bbox_dir, "flirt_template_to_{0}.txt".format(basename))
    m = numpy.loadtxt(affine_omat)
    m_inv = numpy.linalg.inv(m)
    numpy.savetxt(inv_affine_omat, m_inv)

    # Send bbox to native space
    native_bbox_file = os.path.join(bbox_dir, "native_bbox.nii.gz")
    native_bbox_file, _ = flirt(
        in_file=bbox_file,
        ref_file=inputs["inputfile"],
        out=native_bbox_file,
        init=inv_affine_omat,
        applyxfm=True,
        interp="nearestneighbour",
        verbose=verbose,
        shfile=inputs["fsl_sh"])

    # Snap
    if inputs["dosnap"]:
        snaps.append(mosaic(impath=inputs["inputfile"],
                            overlay=native_bbox_file,
                            title="bbox native",
                            outdir=snap_dir))

else:
    bbox_file = None
    native_bbox_file = None


"""
Step 6: Apply deformation to extra files
"""
extramnifiles = []
if inputs["extrafiles"] is not None:
    extra_dir = os.path.join(subjdir, "extra")
    if not os.path.isdir(extra_dir):
        os.mkdir(extra_dir)
    for path in inputs["extrafiles"]:
        if inputs["dononlinear"]:
            raise NotImplementedError(
                "Use applywarp -r MNI152_T1_1mm.nii -i myvolume.nii "
                "-o myvolumeInMNI.nii -w warp_struct2mni.nii "
                "--premat=func2struct.mat")
        else:
            path_mni_file = os.path.join(extra_dir, os.path.basename(path))
            path_mni_file, _ = flirt(
                in_file=path,
                ref_file=cropped_reference_file,
                out=path_mni_file,
                init=affine_omat,
                applyxfm=True,
                interp=inputs["interp"],
                verbose=verbose,
                shfile=inputs["fsl_sh"])
            extramnifiles.append(path_mni_file)              


"""
Update the outputs and save them and the inputs in a 'logs' directory.
"""
logdir = os.path.join(subjdir, "logs")
if not os.path.isdir(logdir):
    os.mkdir(logdir)
params = locals()
outputs = dict([(name, params[name])
               for name in ("cropped_moving_file", "cropped_reference_file",
                            "crop_outputs_trf", "tpm", "tsm",
                            "segmentation_anatfile", "bias_anatfile",
                            "biascorrected_anatfile", "brain_file", "mask_file",
                            "affine_file", "affine_omat", "cout", "iout",
                            "fout", "jout", "refout", "intout", "logout",
                            "native_bbox_file", "bbox_file", "extramnifiles",
                            "snaps")])
for name, final_struct in [("inputs", inputs), ("outputs", outputs),
                           ("runtime", runtime)]:
    log_file = os.path.join(logdir, "{0}.json".format(name))
    with open(log_file, "wt") as open_file:
        json.dump(final_struct, open_file, sort_keys=True, check_circular=True,
                  indent=4)
if verbose > 1:
    print("[info] Outputs:")
    pprint(outputs)
