#! /usr/bin/env python
##########################################################################
# NSAp - Copyright (C) CEA, 2013 - 2017
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

# System import
from __future__ import print_function
import argparse
import os
import shutil
from datetime import datetime
import json
from pprint import pprint
import textwrap
from argparse import RawTextHelpFormatter

# Bredala import
# The whole setup is best-effort: when the import or one of the registrations
# fails, the script goes on without bredala.
try:
    import bredala
    bredala.USE_PROFILER = False
    bredala.register("pyconnectome.utils.segtools",
                     names=["bet2", "fast", "robustfov"])
    bredala.register("pydcmio.plotting.slicer",
                     names=["mosaic"])
    bredala.register("pyconnectome.utils.regtools",
                     names=["flirt"])
except Exception:
    # Unlike a bare 'except', this lets KeyboardInterrupt and SystemExit
    # propagate while still ignoring any regular failure.
    pass

# Package import
from pyconnectome import __version__ as version
from pyconnectome import DEFAULT_FSL_PATH
from pyconnectome.wrapper import FSLWrapper
from pyconnectome.utils.segtools import bet2
from pyconnectome.utils.segtools import fast
from pyconnectome.utils.segtools import robustfov
from pyconnectome.utils.regtools import flirt

# Third party import
import numpy
from pydcmio.plotting.slicer import mosaic


# Parameters to keep trace
__hopla__ = ["runtime", "inputs", "outputs"]


# Script documentation
# A raw string is required here: in a regular literal each trailing backslash
# followed by a newline is a line continuation, which would collapse the
# command example onto a single line in the generated help.
DOC = r"""
FSL tissue segmentation

Tissue segmentation with FSL FAST.

Command example for a tissue segmentation with QC snaps on the
Metastasis data:

python $HOME/git/pyconnectome/pyconnectome/scripts/pyconnectome_tissue_segmentation \
    -o /volatile/nsap/recalage_cathy/segmentation \
    -s 585521174283 \
    -i /volatile/nsap/recalage_cathy/results/585521174283/brain.nii.gz \
    -S \
    -v 2
"""


def is_file(filearg):
    """ Type for argparse - checks that file exists but does not open.

    Parameters
    ----------
    filearg: str
        the path received from the command line.

    Returns
    -------
    filearg: str
        the input path, unchanged.

    Raises
    ------
    argparse.ArgumentTypeError
        if the path does not point to an existing file.
    """
    if not os.path.isfile(filearg):
        # 'ArgumentTypeError' is the exception argparse expects from a 'type'
        # callable. 'ArgumentError' needs an (argument, message) pair, so
        # building it from the message alone fails with a TypeError and the
        # message below would never be displayed.
        raise argparse.ArgumentTypeError(
            "The file '{0}' does not exist!".format(filearg))
    return filearg


def is_directory(dirarg):
    """ Type for argparse - checks that directory exists.

    Parameters
    ----------
    dirarg: str
        the path received from the command line.

    Returns
    -------
    dirarg: str
        the input path, unchanged.

    Raises
    ------
    argparse.ArgumentTypeError
        if the path does not point to an existing directory.
    """
    if not os.path.isdir(dirarg):
        # 'ArgumentTypeError' is the exception argparse expects from a 'type'
        # callable. 'ArgumentError' needs an (argument, message) pair, so
        # building it from the message alone fails with a TypeError and the
        # message below would never be displayed.
        raise argparse.ArgumentTypeError(
            "The directory '{0}' does not exist!".format(dirarg))
    return dirarg


def get_cmd_line_args():
    """ Build the command line parser and parse the script arguments.

    Returns
    -------
    kwargs: dict
        mapping <argument name> -> <argument value>, without the 'verbose'
        entry and with 'fsl_sh' replaced by 'DEFAULT_FSL_PATH' when the
        option was not given.
    verbose: int
        the requested verbosity level.
    """
    cli = argparse.ArgumentParser(
        prog="pyconnectome_tissue_segmentation",
        description=textwrap.dedent(DOC),
        formatter_class=RawTextHelpFormatter)

    # Mandatory options, displayed in their own help section
    mandatory = cli.add_argument_group("required arguments")
    mandatory_specs = [
        (("-o", "--outdir"), dict(
            required=True, metavar="<path>", type=is_directory,
            help="the analysis output directory where <outdir>/<sid> will be "
                 "generated.")),
        (("-s", "--sid"), dict(
            required=True,
            help="the subject identifier.")),
        (("-i", "--inputfile"), dict(
            required=True, type=is_file,
            help="the input MRI volume to be segmented.")),
    ]
    for flags, settings in mandatory_specs:
        mandatory.add_argument(*flags, **settings)

    # Optional options, declared in the order they appear in the help
    optional_specs = [
        (("-S", "--dosnap"), dict(
            action="store_true",
            help="if activated, generate QC snaps.")),
        (("-T", "--intype"), dict(
            default=1, type=int, choices=range(1, 4),
            help="type of the input image 1=T1, 2=T2, 3=PD")),
        (("-K", "--nbklass"), dict(
            default=3, type=int,
            help="the number of class in the FAST modelization.")),
        (("-O", "--docrop"), dict(
            action="store_true",
            help="if set, reduce the FOV of the input image to remove lower "
                 "head and neck.")),
        (("-I", "--brainsize"), dict(
            type=int, default=170,
            help="the brain size used during the cropping.")),
        (("-B", "--dobrain"), dict(
            action="store_true",
            help="if set, use the BET2 routine to segment the subject "
                 "brain.")),
        (("-H", "--brainthr"), dict(
            dest="brainthr", default=0.5, type=float,
            help="the BET2 fractional intensity threshold (0->1).")),
        (("-F", "--fsl-sh"), dict(
            type=is_file, metavar="<path>",
            help="bash script initializing FSL's environment.")),
        (("-v", "--verbose"), dict(
            type=int, choices=[0, 1, 2], default=0,
            help="increase the verbosity level: 0 silent, [1, 2] verbose.")),
    ]
    for flags, settings in optional_specs:
        cli.add_argument(*flags, **settings)

    # Turn the parsed namespace into the keyword arguments of the script;
    # the verbosity is returned on its own.
    options = vars(cli.parse_args())
    verbosity = options.pop("verbose")
    if options["fsl_sh"] is None:
        options["fsl_sh"] = DEFAULT_FSL_PATH

    return options, verbosity


"""
Parse the command line.
"""
# Parse the command line and collect the runtime information that is saved
# in the logs at the end of the script.
inputs, verbose = get_cmd_line_args()
tool = "pyconnectome_tissue_segmentation"
timestamp = datetime.now().isoformat()
tool_version = version
fsl_version = FSLWrapper([], shfile=inputs["fsl_sh"]).version
runtime = {
    "tool": tool,
    "tool_version": tool_version,
    "fsl_version": fsl_version,
    "timestamp": timestamp}

# Every product of the script is generated in <outdir>/<sid>
subjdir = os.path.join(inputs["outdir"], inputs["sid"])
if not os.path.isdir(subjdir):
    os.mkdir(subjdir)
outputs = None
snaps = []
# 'snap_dir' is only defined when QC snaps are requested: every later use is
# guarded by the same 'dosnap' option.
if inputs["dosnap"]:
    snap_dir = os.path.join(subjdir, "snap")
    if not os.path.isdir(snap_dir):
        os.mkdir(snap_dir)
if verbose > 0:
    # Plain 'print' for the labels: 'pprint' displays the repr of a string,
    # i.e. the text wrapped in quotes.
    print("[info] Starting tissue segmentation ...")
    print("[info] Runtime:")
    pprint(runtime)
    print("[info] Inputs:")
    pprint(inputs)


"""
Step 0: Remove lower head and neck.
"""
if not inputs["docrop"]:
    # No cropping requested: the input volume feeds the next step as is and
    # there is no transformation file to report.
    cropped_file = inputs["inputfile"]
    cropped_trf = None
else:
    # Prepare the 'robustfov' working directory and the product names
    basename = os.path.basename(inputs["inputfile"])
    crop_dir = os.path.join(subjdir, "robustfov")
    if not os.path.isdir(crop_dir):
        os.mkdir(crop_dir)
    cropped_file = os.path.join(crop_dir, "robustfov_" + basename)
    cropped_trf = os.path.join(
        crop_dir, "robustfov_" + basename.split(".")[0] + ".txt")

    # Crop the input volume, keeping track of the associated matrix file
    robustfov(
        input_file=inputs["inputfile"],
        output_file=cropped_file,
        brain_size=inputs["brainsize"],
        matrix_file=cropped_trf,
        fsl_sh=inputs["fsl_sh"])

    # Restore original shape: apply the saved matrix to the cropped volume,
    # using the input volume as reference.
    # NOTE(review): the resampled image is neither used by the next steps of
    # this script nor saved in the outputs - confirm this is intended.
    cropped_und_file, _ = flirt(
        in_file=cropped_file,
        ref_file=inputs["inputfile"],
        out=os.path.join(crop_dir, "robustfov_und_" + basename),
        init=cropped_trf,
        applyxfm=True,
        interp="spline",
        verbose=verbose,
        shfile=inputs["fsl_sh"])

    # Perform QC
    if inputs["dosnap"]:
        crop_snap = mosaic(
            impath=cropped_file, title="robustfov", outdir=snap_dir)
        snaps.append(crop_snap)


"""
Step 1: Brain extraction
"""
if inputs["dobrain"]:
    bet_dir = os.path.join(subjdir, "bet")
    if not os.path.isdir(bet_dir):
        os.mkdir(bet_dir)
    # 'bet2' returns an 11-item tuple that is unpacked into one name per
    # product; only the brain image and its mask are requested here.
    (brain_file, mask_file, mesh_file, outline_file, inskull_mask_file,
     inskull_mesh_file, outskull_mask_file, outskull_mesh_file,
     outskin_mask_file, outskin_mesh_file, skull_mask_file) = bet2(
        input_file=cropped_file,
        output_fileroot=bet_dir,
        outline=False,
        mask=True,
        skull=False,
        nooutput=False,
        f=inputs["brainthr"],
        g=0,
        radius=None,
        smooth=None,
        c=None,
        threshold=False,
        mesh=False,
        shfile=inputs["fsl_sh"])
    if inputs["dosnap"]:
        snaps.append(mosaic(impath=brain_file,
                            title="bet",
                            outdir=snap_dir))
else:
    # No brain extraction: the (possibly cropped) input is used as the brain
    # image and every BET2 product is set to None. The names are the same as
    # the ones unpacked from 'bet2' in the other branch.
    mask_file, mesh_file, outline_file, inskull_mask_file = (None, ) * 4
    inskull_mesh_file, outskull_mesh_file, outskin_mask_file = (None, ) * 3
    skull_mask_file, outskin_mesh_file, outskull_mask_file = (None, ) * 3
    brain_file = cropped_file
if verbose > 0:
    print("[result] Brain image: {0}.".format(brain_file))
    print("[result] Brain mask image: {0}.".format(mask_file))


"""
Step 2: Tissue segmentation and spatial intensity variations correction.
"""
fast_dir = os.path.join(subjdir, "fast")
if not os.path.isdir(fast_dir):
    os.mkdir(fast_dir)

# Output root: <fast_dir>/<input file name up to its first dot>
input_stem = os.path.basename(inputs["inputfile"]).split(".")[0]
fast_fileroot = os.path.join(fast_dir, input_stem)

# Run FAST on the brain image, requesting the segmentation, the bias field,
# the bias corrected image and the probability maps.
fast_options = dict(
    input_file=brain_file,
    out_fileroot=fast_fileroot,
    klass=inputs["nbklass"],
    im_type=inputs["intype"],
    segments=True,
    bias_field=True,
    bias_corrected_im=True,
    probabilities=True,
    shfile=inputs["fsl_sh"])
fast_results = fast(**fast_options)
(tpm, tsm, segmentation_anatfile, bias_anatfile,
 biascorrected_anatfile) = fast_results

# Perform QC
if inputs["dosnap"]:
    fast_snap = mosaic(
        impath=segmentation_anatfile,
        title="fast_segmentation",
        outdir=snap_dir)
    snaps.append(fast_snap)

# NOTE(review): 'Antomical' is presumably a typo for 'Anatomical'; the
# messages are emitted unchanged here.
if verbose > 0:
    for message, path in (
            ("[result] Antomical tissues: {0}.", segmentation_anatfile),
            ("[result] Antomical bias field: {0}.", bias_anatfile),
            ("[result] Antomical bias corrected: {0}.",
             biascorrected_anatfile)):
        print(message.format(path))


"""
Update the outputs and save them and the inputs in a 'logs' directory.
"""
logdir = os.path.join(subjdir, "logs")
if not os.path.isdir(logdir):
    os.mkdir(logdir)

# Gather the script products by name from the module namespace ('locals()'
# at module level).
params = locals()
output_names = (
    "cropped_file", "cropped_trf", "tpm", "tsm", "segmentation_anatfile",
    "bias_anatfile", "biascorrected_anatfile", "brain_file", "mask_file",
    "snaps")
outputs = dict((key, params[key]) for key in output_names)

# One JSON file per structure: inputs.json, outputs.json and runtime.json
structures = {"inputs": inputs, "outputs": outputs, "runtime": runtime}
for name in ("inputs", "outputs", "runtime"):
    log_file = os.path.join(logdir, "{0}.json".format(name))
    with open(log_file, "wt") as open_file:
        json.dump(structures[name], open_file, indent=4, sort_keys=True,
                  check_circular=True)

if verbose > 1:
    print("[info] Outputs:")
    pprint(outputs)
