#! /usr/bin/env python
# -*- coding: utf-8 -*
##########################################################################
# NSAp - Copyright (C) CEA, 2017
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html for details.
##########################################################################

# System import
import os
import argparse
import json
from pprint import pprint
from datetime import datetime
import textwrap
from argparse import RawTextHelpFormatter

# Bredala module
try:
    # Bredala is an optional dependency: when it is available, its profiler
    # is switched off for this script.
    import bredala
    bredala.USE_PROFILER = False
    #bredala.register("pyconnectome.tractography.filtering",
    #                 names=["life"])
except Exception:
    # Best effort only: any failure while loading the optional module is
    # ignored, but KeyboardInterrupt/SystemExit are no longer swallowed as
    # they were with a bare 'except:'.
    pass

# Package import
from pyconnectome import __version__ as version


# Parameters to keep trace
__hopla__ = ["runtime", "inputs", "outputs"]


# Help text displayed as the description of the command line parser. The
# verbosity option of this script is '-V' (upper case): the examples must use
# it, otherwise the documented commands are rejected by the parser.
DOC = """
Registration with ANTS
----------------------

Can perform an Affine or Affine+NL registration with ANTS.

Command example on the SENIOR data - monomodal:

python $HOME/git/pyconnectome/pyconnectome/scripts/pyconnectome_ants_register \
    -b /usr/lib/ants \
    -o /volatile/nsap/senior/ants \
    -i /neurospin/senior/nsap/data/V0/nifti/nc140436/000004_3DT1/3DT1.nii.gz \
    -r /neurospin/nsap/processed/senior_t2star/data/fbrain/senior/ants/T_template0.nii.gz \
    -w 1 \
    -D 3 \
    -G 0.1 \
    -J 6 \
    -N \
    -V 2

Command example on the SENIOR data - multimodal:

python $HOME/git/pyconnectome/pyconnectome/scripts/pyconnectome_ants_register \
    -b /usr/lib/ants \
    -o /volatile/nsap/senior/ants \
    -i /neurospin/senior/nsap/data/V0/nifti/nc140436/000004_3DT1/3DT1.nii.gz /neurospin/senior/nsap/data/V0/nifti/nc140436/000004_3DT1/3DT1.nii.gz \
    -r /neurospin/nsap/processed/senior_t2star/data/fbrain/senior/ants/T_template0.nii.gz /neurospin/nsap/processed/senior_t2star/data/fbrain/senior/ants/T_template0.nii.gz \
    -w 0.5 0.2 \
    -D 3 \
    -G 0.2 \
    -J 6 \
    -N \
    -V 2
"""


def is_file(filepath):
    """ Check file's existence - argparse 'type' argument.

    Parameters
    ----------
    filepath: str
        the path received from the command line.

    Returns
    -------
    filepath: str
        the input path, unchanged, when it points to an existing file.

    Raises
    ------
    argparse.ArgumentTypeError
        if the path does not point to an existing file.
    """
    if not os.path.isfile(filepath):
        # ArgumentTypeError is the exception argparse expects from a 'type'
        # callable: its message is reported in the usage error.
        # ArgumentError requires (argument, message) and cannot be built from
        # a single message, which made the custom message disappear.
        raise argparse.ArgumentTypeError(
            "File does not exist: %s" % filepath)
    return filepath

def is_directory(dirarg):
    """ Type for argparse - checks that directory exists.

    Parameters
    ----------
    dirarg: str
        the path received from the command line.

    Returns
    -------
    dirarg: str
        the input path, unchanged, when it points to an existing directory.

    Raises
    ------
    argparse.ArgumentTypeError
        if the path does not point to an existing directory.
    """
    if not os.path.isdir(dirarg):
        # ArgumentTypeError is the exception argparse expects from a 'type'
        # callable: its message is reported in the usage error.
        # ArgumentError requires (argument, message) and cannot be built from
        # a single message, which made the custom message disappear.
        raise argparse.ArgumentTypeError(
            "The directory '{0}' does not exist!".format(dirarg))
    return dirarg


def get_cmd_line_args():
    """
    Create a command line argument parser and return a dict mapping
    <argument name> -> <argument value>.

    Returns
    -------
    kwargs: dict
        the parsed arguments, without the verbosity level.
    verbose: int
        the requested verbosity level (0, 1 or 2).
    """
    parser = argparse.ArgumentParser(
        prog="python pyconnectome_ants_register",
        description=textwrap.dedent(DOC),
        formatter_class=RawTextHelpFormatter)

    # Required arguments
    required = parser.add_argument_group("required arguments")
    required.add_argument(
        "-o", "--outdir",
        type=is_directory, required=True, metavar="<path>",
        help="Directory where to output.")
    required.add_argument(
        "-b", "--binaries",
        type=is_directory, required=True, metavar="<path>",
        help="Path to the ANTS binaries.")
    required.add_argument(
        "-i", "--images",
        type=is_file, required=True, metavar="<path>", nargs="+",
        help="Path to the input images. One can specify more than one file "
             "for multi-modal registration (e.g. t1 and t2).")
    required.add_argument(
        "-r", "--references",
        type=is_file, required=True, metavar="<path>", nargs="+",
        help="Path to the reference images. One can specify more than one "
             "file for multi-modal registration (e.g. t1 and t2).")
    required.add_argument(
        "-w", "--weights",
        type=float, required=True, nargs="+",
        help="Modality weights used in the similarity metric.")

    # Optional arguments
    parser.add_argument(
        "-C", "--cross-modality",
        action="store_true", default=False,
        help="If set perform a cross modality registration.")
    parser.add_argument(
        "-B", "--nobfc",
        action="store_true", default=False,
        help="If set no bias field correction is performed.")
    parser.add_argument(
        "-D", "--dimensions",
        type=int, choices=[2, 3, 4], default=3,
        help="The images dimensions.")
    parser.add_argument(
        "-N", "--dononlinear",
        action="store_true",
        help="if set, use the SyN routine to align the subject brain to the "
             "template with a non linear transformation.")
    parser.add_argument(
        "-G", "--gradient-step",
        type=float, default=0.25,
        help="Smaller in magnitude results in more cautious steps. Use "
             "smaller steps to refine template details. 0.25 is an upper "
             "(aggressive) limit for this parameter.")
    parser.add_argument(
        "-J", "--nbthread",
        type=int, default=2,
        help="Number of cpu cores to use locally.")
    parser.add_argument(
        "-R", "--restrict-deformation",
        type=int, nargs="*",
        help="This option allows the user to restrict the optimization of the "
             "displacement field.")
    parser.add_argument(
        "-L", "--moving-mask",
        type=is_file, metavar="<path>",
        help="Path to a mask image in the moving space: restrict "
             "registration computations only to voxels with the mask (need to "
             "specify the reference-mask).")
    parser.add_argument(
        "-A", "--reference-mask",
        type=is_file, metavar="<path>",
        help="Path to a mask image in the reference space: restrict "
             "registration computations only to voxels with the mask.")
    parser.add_argument(
        "-T", "--apply-to",
        metavar="<paths>", nargs="+",
        help="Apply the computed deformation on these images: "
             "[path,apply_inverse,interp] with apply_inverse (0, 1) and "
             "interp (0:nearest neighbor, 1:bspline).")
    parser.add_argument(
        "-V", "--verbose",
        type=int, choices=[0, 1, 2], default=2,
        help="Increase the verbosity level: 0 silent, [1, 2] verbose.")

    # Create a dict of arguments to pass to the 'main' function: the
    # verbosity level is returned separately and removed from the dict.
    args = parser.parse_args()
    kwargs = vars(args)
    verbose = kwargs.pop("verbose")

    return kwargs, verbose


"""
Parse the command line.
"""
inputs, verbose = get_cmd_line_args()
tool = "pyconnectome_ants_registration"
timestamp = datetime.now().isoformat()
tool_version = version
# Gather the runtime description from the names defined just above.
params = locals()
runtime = {
    name: params[name] for name in ("tool", "tool_version", "timestamp")}
outputs = None
if verbose > 0:
    # Display the execution context before starting the processing.
    for item in ("[info] Starting ANTS registration...",
                 "[info] Runtime:",
                 runtime,
                 "[info] Inputs:",
                 inputs):
        pprint(item)

"""
Start registration.
"""
import numpy
import subprocess


def to_ants_list(l):
    """ Build the ANTS representation of a Python list.

    Parameters
    ----------
    l: list
        the elements to serialize, each one is converted with 'str'.

    Returns
    -------
    a: str
        the elements joined with the 'x' separator.
    """
    return "x".join(map(str, l))


# Check inputs: one weight and one reference are expected for each image
nb_modalities = len(inputs["images"])
if len(inputs["weights"]) != nb_modalities:
    raise ValueError("Number of modalities mismatched in inputs.")
if len(inputs["references"]) != nb_modalities:
    raise ValueError("Number of references mismatched number of inputs.")


# Define parameters: in the two-element lists, the first item is read when
# building the affine stage and the second one when building the SyN stage.
interp_map = {"0": "NearestNeighbor", "1": "BSpline"}
transforms = ["Affine", "SyN"]
transform_parameters = [
    (inputs["gradient_step"], ),
    (inputs["gradient_step"], 3.0, 0.0)]
metric_weight = [inputs["weights"]] * 2
radius_or_number_of_bins = [32, 4]
sampling_strategy = ["Regular", None]
sampling_percentage = [0.25, None]
number_of_iterations = [[1000, 500, 250, 100], [100, 100, 70, 20]]
convergence_threshold = [1.e-6, 1.e-9]
convergence_window_size = [10, 10]
shrink_factors = [[6, 4, 2, 1], [6, 4, 2, 1]]
smoothing_sigmas = [[4, 2, 1, 0], [4, 2, 1, 0]]
# Without user restriction, the value 1 is used for each of the three items.
restrict_deformation = inputs["restrict_deformation"]
if restrict_deformation is None:
    restrict_deformation = [1, 1, 1]
# Prefix of the files generated by the registration.
baseoutput = os.path.join(inputs["outdir"], "ants_")


# Do bias field corrections: each input image is corrected with
# N4BiasFieldCorrection and replaced by its corrected version in the inputs.
# -x: mask image
if not inputs["nobfc"]:
    print("Starting bias field corrections...")
    # Build the child environment once, on a copy: aliasing 'os.environ'
    # modified the environment of this process and appended the binaries
    # directory to PATH again at every iteration.
    environment = os.environ.copy()
    environment["ANTSPATH"] = inputs["binaries"]
    search_path = environment.get("PATH", "")
    if search_path:
        environment["PATH"] = search_path + os.pathsep + inputs["binaries"]
    else:
        environment["PATH"] = inputs["binaries"]
    # Honor the '-J/--nbthread' option, which was parsed but never used:
    # this is the variable the ANTS/ITK tools read to size their thread pool.
    environment["ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS"] = str(
        inputs["nbthread"])
    for index in range(nb_modalities):
        biascorrectedfile = os.path.join(
            inputs["outdir"], "ants_BFC{0}.nii.gz".format(index))
        cmd0 = [
            "N4BiasFieldCorrection",
            "-d", str(inputs["dimensions"]),
            "-b", "[200]",  # b-spline fitting parameters as the isotropic sizing of the mesh elements
            "-c", "[50x50x40x30, 0.00000001]",  # convergence as number of iterations at each resolution and convergence threshold
            "-i", inputs["images"][index],
            "-o", biascorrectedfile,
            "-r", "0",  # no intensity rescaling
            "-s", "2",  # shrink factor
            "-v", "1" if verbose > 0 else "0"]
        if verbose > 0:
            print("-" * 20)
            print(" ".join(cmd0))
            print("-" * 20)
        process = subprocess.Popen(cmd0,
                                   env=environment,
                                   stdout=subprocess.PIPE,
                                   stderr=subprocess.PIPE)
        stdout, stderr = process.communicate()
        exitcode = process.returncode
        if exitcode != 0:
            raise ValueError("Command '{0}' failed:: {1}".format(
                " ".join(cmd0), stderr + stdout))
        # The following steps work on the bias corrected image.
        inputs["images"][index] = biascorrectedfile
    print("Done.")
else:
    # No correction requested: nothing to report in the outputs.
    cmd0 = None


# Basic registration options
print("Starting registration...")
# Histogram matching ('-u') is switched off for cross modality registrations.
use_histogram_matching = "0" if inputs["cross_modality"] else "1"
cmd1 = [
    "antsRegistration",
    "-d", str(inputs["dimensions"]),
    "--float", "1",
    "-u", use_histogram_matching,
    "-n", "Linear",
    "-w", "[0.005,0.995]",
    "-z", "1",
    "-r", "[{0},{1},1]".format(inputs["references"][0], inputs["images"][0]),
    "-o", baseoutput]
# Affine options: one mutual information metric per modality
cmd1 += ["-t", "Affine[{0}]".format(inputs["gradient_step"])]
for index in range(nb_modalities):
    cmd1 += ["-m", "MI[{0},{1},{2},{3},{4},{5}]".format(
        inputs["references"][index],
        inputs["images"][index],
        metric_weight[0][index],
        radius_or_number_of_bins[0],
        sampling_strategy[0],
        sampling_percentage[0])]
cmd1 += [
    "-c", "[{0},{1},{2}]".format(
        to_ants_list(number_of_iterations[0]),
        convergence_threshold[0],
        convergence_window_size[0]),
    "-f", to_ants_list(shrink_factors[0]),
    "-s", to_ants_list(smoothing_sigmas[0])]
# NL options: one cross correlation metric per modality
if inputs["dononlinear"]:
    cmd1 += ["-t", "SyN[{0},3,0]".format(inputs["gradient_step"])]
    for index in range(nb_modalities):
        cmd1 += ["-m", "CC[{0},{1},{2},{3}]".format(
            inputs["references"][index],
            inputs["images"][index],
            metric_weight[1][index],
            radius_or_number_of_bins[1])]
    cmd1 += [
        "-c", "[{0},{1},{2}]".format(
            to_ants_list(number_of_iterations[1]),
            convergence_threshold[1],
            convergence_window_size[1]),
        "-f", to_ants_list(shrink_factors[1]),
        "-s", to_ants_list(smoothing_sigmas[1]),
        "-g", to_ants_list(restrict_deformation)]
if inputs["moving_mask"] is not None or inputs["reference_mask"] is not None:
    # NOTE(review): both keys are always defined by the argument parser, so
    # the '' default of 'get' is never used and a mask that was not given is
    # written as 'None' in the option. Left unchanged - confirm the
    # placeholder expected by antsRegistration before changing it.
    cmd1 += [
        "-x", "[{0}, {1}]".format(inputs.get("reference_mask", ""),
                                  inputs.get("moving_mask", ""))]
if verbose > 0:
    print("-" * 20)
    print(" ".join(cmd1))
    print("-" * 20)
# Work on a copy of the environment: aliasing 'os.environ' modified the
# environment of this process and appended the binaries directory to PATH
# one more time.
environment = os.environ.copy()
environment["ANTSPATH"] = inputs["binaries"]
search_path = environment.get("PATH", "")
if search_path:
    environment["PATH"] = search_path + os.pathsep + inputs["binaries"]
else:
    environment["PATH"] = inputs["binaries"]
# Honor the '-J/--nbthread' option, which was parsed but never used: this is
# the variable the ANTS/ITK tools read to size their thread pool.
environment["ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS"] = str(inputs["nbthread"])
process = subprocess.Popen(cmd1,
                           env=environment,
                           stdout=subprocess.PIPE,
                           stderr=subprocess.PIPE)
stdout, stderr = process.communicate()
exitcode = process.returncode
if exitcode != 0:
    raise ValueError("Command '{0}' failed:: {1}".format(
        " ".join(cmd1), stderr + stdout))
# Transformation files expected from the registration, named from the
# output prefix: the deformation fields are only set for the non linear case.
affinefile = baseoutput + "0GenericAffine.mat"
if inputs["dononlinear"]:
    fieldfile = baseoutput + "1Warp.nii.gz"
    invfieldfile = baseoutput + "1InverseWarp.nii.gz"
else:
    fieldfile, invfieldfile = (None, None)
print("Done.")


# Apply deformation
print("Apply deformation...")
def apply_transforms(moving_file, reference_file, outdir, fieldfile,
                     affinefile, dim, interp, inverse):
    """ Build the 'antsApplyTransforms' command that resamples an image.

    Parameters
    ----------
    moving_file: str
        the image to resample.
    reference_file: str
        the image passed with the '-r' option.
    outdir: str
        the folder where the resampled image is generated.
    fieldfile: str or None
        the deformation field to apply, None to apply the affine only.
    affinefile: str
        the affine transformation file.
    dim: int
        the images dimension.
    interp: str
        the interpolation name passed with the '-n' option.
    inverse: bool
        if set, the affine is passed as '[<affinefile>,1]' and listed before
        the deformation field, and the output is named 'WarpToNative'.
        Otherwise the field is listed before the affine and the output is
        named 'WarpToTemplate'.

    Returns
    -------
    cmd: list of str
        the command to execute.
    outfile: str
        the path to the resampled image.
    """
    # Name of the result: text before the first dot of the input file name.
    stem = os.path.basename(moving_file).split(".")[0]
    pattern = ("ants_2WarpToNative_{0}.nii.gz" if inverse
               else "ants_2WarpToTemplate_{0}.nii.gz")
    outfile = os.path.join(outdir, pattern.format(stem))

    # Options shared by all the cases
    cmd = [
        "antsApplyTransforms",
        "-d", str(dim),
        "--float", "1",
        "-i", moving_file,
        "-o", outfile,
        "-r", reference_file,
        "-n", interp]

    # Ordered list of transformations
    affine = "[{0},1]".format(affinefile) if inverse else affinefile
    if fieldfile is None:
        stack = [affine]
    elif inverse:
        stack = [affine, fieldfile]
    else:
        stack = [fieldfile, affine]
    for transform in stack:
        cmd += ["-t", transform]

    # The interpolation option is repeated when a deformation field is used.
    if fieldfile is not None:
        cmd += ["-n", interp]

    return cmd, outfile

# Build the commands first: one per modality to bring each input image in the
# reference space, then one per user requested image.
warpedfiles = []
commands = []
for index in range(nb_modalities):
    cmd, outfile = apply_transforms(
        inputs["images"][index], inputs["references"][index], inputs["outdir"],
        fieldfile, affinefile, inputs["dimensions"], "BSpline", False)
    commands.append(cmd)
    warpedfiles.append(outfile)
if inputs["apply_to"] is not None:
    for item in inputs["apply_to"]:
        # Each item is expected as 'path,apply_inverse,interp'.
        path, use_inverse, interp = item.split(",")
        if use_inverse == "1":
            # Inverse transformation: resample on the first input image.
            cmd, outfile = apply_transforms(
                path, inputs["images"][0], inputs["outdir"], invfieldfile,
                affinefile, inputs["dimensions"], interp_map[interp], True)
        else:
            # Direct transformation: resample on the first reference image.
            cmd, outfile = apply_transforms(
                path, inputs["references"][0], inputs["outdir"], fieldfile,
                affinefile, inputs["dimensions"], interp_map[interp], False)
        commands.append(cmd)
        warpedfiles.append(outfile)
# Build the child environment once, on a copy: aliasing 'os.environ' modified
# the environment of this process and appended the binaries directory to PATH
# again for every command.
environment = os.environ.copy()
environment["ANTSPATH"] = inputs["binaries"]
search_path = environment.get("PATH", "")
if search_path:
    environment["PATH"] = search_path + os.pathsep + inputs["binaries"]
else:
    environment["PATH"] = inputs["binaries"]
# Honor the '-J/--nbthread' option, which was parsed but never used: this is
# the variable the ANTS/ITK tools read to size their thread pool.
environment["ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS"] = str(inputs["nbthread"])
for cmd in commands:
    if verbose > 0:
        print("-" * 20)
        print(" ".join(cmd))
        print("-" * 20)
    process = subprocess.Popen(cmd,
                               env=environment,
                               stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE)
    stdout, stderr = process.communicate()
    exitcode = process.returncode
    if exitcode != 0:
        raise ValueError("Command '{0}' failed:: {1}".format(
            " ".join(cmd), stderr + stdout))
print("Done.")


"""
Update the outputs and save them and the inputs in a 'logs' directory.
"""
logdir = os.path.join(inputs["outdir"], "logs")
if not os.path.isdir(logdir):
    os.mkdir(logdir)
# Collect the generated commands and files from the names defined above.
params = locals()
outputs = {
    name: params[name]
    for name in ("cmd0", "cmd1", "commands", "affinefile", "invfieldfile",
                 "fieldfile", "warpedfiles")}
# One JSON file is written for each structure to keep trace of.
structures = (("inputs", inputs), ("outputs", outputs), ("runtime", runtime))
for name, final_struct in structures:
    log_file = os.path.join(logdir, "{0}.json".format(name))
    with open(log_file, "wt") as open_file:
        json.dump(final_struct, open_file, sort_keys=True, check_circular=True,
                  indent=4)
if verbose > 1:
    pprint("[info] Outputs:")
    pprint(outputs)


