#!/usr/bin/env python
"""
Create an HDF5 format file from a VCF using GNU parallel for parallelization.
Parallelizes vcf2npy calls based on regions derived from a fasta file or
given directly using --regions-string or --regions-file.
Then runs vcfnpy2hdf5 serially on each chromosome/scaffold/contig.

Examples:
    vcf2hdf5_parallel --vcf in.vcf.gz --fasta ref.fa
    vcf2hdf5_parallel --vcf in.vcf.gz --fasta ref.fa --cleanup --jobs -1
    vcf2hdf5_parallel --vcf in.vcf.gz --regions-string '2L:0-20000 X' --out foo

Author: Travis Collier <https://github.com/travc>
"""

from __future__ import print_function
import sys
import os
import time
import argparse
import logging
import math
import shlex  # for split
import shutil  # for rmtree and which
import subprocess
import errno
from collections import OrderedDict

from numpy import __version__ as numpy_version
from h5py import __version__ as h5py_version
from vcfnp import __version__ as vcfnp_version

MAX_LOGGING_LEVEL = logging.CRITICAL
DEFAULT_LOGGING_LEVEL = logging.INFO


def setup_logger(verbose_level):
    """Configure the root logger via ``logging.basicConfig``.

    The threshold starts at DEFAULT_LOGGING_LEVEL and is lowered by 10 for
    every unit of ``verbose_level`` (a negative value raises it instead).
    The result is clamped so it never exceeds MAX_LOGGING_LEVEL and never
    drops below 0.
    """
    threshold = DEFAULT_LOGGING_LEVEL - (verbose_level * 10)
    # clamp from above first, then from below (same order as min-then-max)
    if threshold > MAX_LOGGING_LEVEL:
        threshold = MAX_LOGGING_LEVEL
    if threshold < 0:
        threshold = 0
    record_format = ('%(levelname)s %(asctime)s '
                     '[%(module)s:%(lineno)s %(funcName)s] :: %(message)s')
    logging.basicConfig(format=record_format, level=threshold)


def exit(retval=0):
    """Log the total elapsed time and terminate via ``sys.exit``.

    The elapsed time is measured from the module global ``g_start_tic``,
    which ``Main`` sets when it starts.  If that global has not been set
    yet (e.g. this is called before ``Main`` ran), the elapsed time is
    simply omitted from the message instead of failing with a NameError.

    Args:
        retval: exit status handed to ``sys.exit`` (default 0).

    Raises:
        SystemExit: always.
    """
    start_tic = globals().get('g_start_tic')
    if start_tic is None:
        logging.info("All done...")
    else:
        logging.info("All done... %g sec elapsed", time.time() - start_tic)
    sys.exit(retval)


def _build_arg_parser():
    """Return the argparse parser describing every command-line option.

    The module docstring is used as the description and the epilog reports
    the vcfnp, Python, NumPy and h5py versions.
    """
    pyv = '.'.join(map(str, sys.version_info[:3]))
    epilog = "Version: vcfnp %s (Python %s, NumPy %s, h5py %s)" % \
             (vcfnp_version, pyv, numpy_version, h5py_version)
    parser = argparse.ArgumentParser(
        description=__doc__,
        epilog=epilog,
        formatter_class=argparse.RawDescriptionHelpFormatter,
        add_help=False)

    req_group = parser.add_argument_group(title="required")
    req_group.add_argument('--vcf',
                           required=True,
                           help="VCF file to extract data from")

    fasta_group = parser.add_argument_group(
        title="fasta file",
        description="regions will be derived from fasta file"
                    "\n(ignored if --regions-string or --regions-file "
                    "is given)")
    fasta_group.add_argument('--fasta',
                             default=None,
                             help="fasta reference to base regions on "
                                  "(ignored if --regions-string or "
                                  "--regions-file "
                                  "is given)")
    fasta_group.add_argument('--task-size',
                             default='500e3',
                             help="size in bp of regions derived from fasta "
                                  "(default=%(default)s)")

    regions_group = parser.add_argument_group(
        title="regions list or file",
        description="regions taken from string argument"
                    " and/or a regions file"
                    "\n(overrides --fasta)")
    regions_group.add_argument('--regions-string',
                               default=None,
                               help="string listing (whitespace separated) "
                                    "regions "
                                    "to extract in parallel "
                                    "(overrides --fasta)")
    regions_group.add_argument('--regions-file',
                               default=None,
                               help="file listing regions to extract in "
                                    "parallel "
                                    "(overrides --fasta)")
    regions_group.add_argument('--regions-column', type=int, default=0,
                               help="column (0-based) of regions file with "
                                    "region strings "
                                    "(default=%(default)s)")

    opt_group = parser.add_argument_group(title="options")
    opt_group.add_argument('--out',
                           default=None,
                           help="output prefix, "
                                "default is VCF stripped of path and "
                                "extension")

    opt_group.add_argument('--force', action='store_true',
                           help="remove existing cache directory and HDF5 "
                                "file if they exist (default is to abort)")
    opt_group.add_argument('--cleanup', action='store_true',
                           help="remove cache and log files when finished ")
    opt_group.add_argument('-v', '--verbose', action='count', default=0,
                           help="increase logging verbosity (count)")
    opt_group.add_argument('-q', '--quiet', action='count', default=0,
                           help="decrease logging verbosity (count)")
    opt_group.add_argument('-h', '--help', action='help',
                           help="show this help message and exit")

    opt_group.add_argument('--hdf5-compression', default='gzip',
                           help="compression method for HDF5 output file "
                                "(default=%(default)s)")
    opt_group.add_argument('--hdf5-compression-opts', default='1',
                           help="compression option (level) for HDF5 output "
                                "file (default=%(default)s)")
    opt_group.add_argument('--hdf5-chunk-size',
                           dest='hdf5_chunk_size', type=int,
                           default=2 ** 20, metavar='NBYTES',
                           help='HDF5 chunk size in bytes (defaults to 1Mb)')
    opt_group.add_argument('--hdf5-chunk-width',
                           dest='hdf5_chunk_width', type=int, metavar='N',
                           default=10,
                           help='HDF5 chunk width for 2 dimensional datasets ('
                                'defaults to 10)')

    parallel_options_group = parser.add_argument_group(
        title="options passed to parallel")
    parallel_options_group.add_argument('--jobs', default='100%',
                                        help="jobs argument passed to parallel"
                                             " (default=%(default)s)")
    parallel_options_group.add_argument('--nice', default='9',
                                        help="nice argument passed to parallel"
                                             " (default=%(default)s)")

    exec_group = parser.add_argument_group(title='executables')
    exec_group.add_argument('--parallel-exec', default='/usr/bin/parallel',
                            help="parallel executable (must be full path)"
                                 " (default=%(default)s)")
    exec_group.add_argument('--vcf2npy-exec', default='vcf2npy',
                            help="vcf2npy executable"
                                 " (default=%(default)s)")
    exec_group.add_argument('--vcfnpy2hdf5-exec', default='vcfnpy2hdf5',
                            help="vcfnpy2hdf5 executable"
                                 " (default=%(default)s)")
    return parser


def _resolve_executable(name):
    """Return the path of executable `name`, or exit(2) if it is not found."""
    path = shutil.which(name)
    if path is None:
        logging.critical("executable '{}' not found on PATH "
                         "(or not executable)".format(name))
        sys.exit(2)
    return path


def _load_chrom_sizes(fasta):
    """Return an OrderedDict of chrom name -> length for the fasta file.

    A '<fasta>.fai' index is used when it exists (first column is the name,
    second the length); otherwise the fasta is opened with pyfasta.
    """
    chrom_sizes = OrderedDict()
    fai_filename = fasta + '.fai'
    if os.path.isfile(fai_filename):
        logging.info("deriving regions using 'samtools faidx' index file "
                     "'{}'".format(fai_filename))
        with open(fai_filename) as infile:
            for line in infile:
                line = line.partition('#')[0].rstrip()
                if not line:
                    continue
                fields = line.split()
                chrom_sizes[fields[0]] = int(fields[1])
    else:
        logging.info("Deriving regions using pyfasta from "
                     "'{}'".format(fasta))
        # imported lazily so pyfasta is only required when no .fai exists
        import pyfasta
        genome = pyfasta.Fasta(fasta)
        for chrom in genome:
            chrom_sizes[chrom] = len(genome[chrom])
    return chrom_sizes


def _split_into_regions(chrom_sizes, task_size):
    """Return 'chrom:start-end' strings tiling each chrom in task_size steps.

    Coordinates are 1-based, inclusive and zero-padded to the number of
    digits of the chrom length.  A zero-length chrom yields no regions.
    """
    regions = []
    for chrom, chrom_size in chrom_sizes.items():
        chrom_size = int(chrom_size)
        # digit count of the largest coordinate; unlike ceil(log10(size))
        # this is defined for size 0 and correct for exact powers of ten
        n_fill = len(str(chrom_size))
        for start in range(0, chrom_size, task_size):
            end = min(start + task_size, chrom_size)
            regions.append(chrom + ':' +
                           str(start + 1).zfill(n_fill) + '-' +
                           str(end).zfill(n_fill))
    return regions


def _add_region(region, region_list, chrom_list):
    """Append `region` and register its chrom (text before the first ':')."""
    region_list.append(region)
    chrom = region.split(':')[0]
    if chrom not in chrom_list:
        chrom_list[chrom] = True


def Main(argv=None):
    """Convert a VCF to HDF5: parallel vcf2npy per region, then vcfnpy2hdf5.

    Steps: parse and validate arguments, build the list of regions (from a
    fasta / .fai, a regions string and/or a regions file), create the cache
    directory, run vcf2npy for every region and array type through GNU
    parallel, then run vcfnpy2hdf5 once per chrom, and optionally remove
    the cache and log files.

    Sets the module global ``g_start_tic`` (read by ``exit``).

    Args:
        argv: list of command-line arguments; None means ``sys.argv[1:]``.

    Raises:
        SystemExit: always. Status 0 on success, 2 for invalid arguments,
            pre-existing outputs or missing executables, 1 when parallel
            or vcfnpy2hdf5 reports a failure.
    """
    global g_start_tic
    g_start_tic = time.time()

    # parse arguments (argv=None makes argparse fall back to sys.argv[1:])
    args = _build_arg_parser().parse_args(argv)

    # setup logger first
    setup_logger(verbose_level=args.verbose - args.quiet)
    logging.info("Start")

    # additional argument handling and validation
    if args.out is None:
        # strip the directory and up to two extensions (e.g. '.vcf.gz')
        args.out = os.path.splitext(os.path.basename(args.vcf))[0]
        if args.out.endswith('.vcf'):
            args.out = os.path.splitext(args.out)[0]
    logging.info("output prefix = '{}'".format(args.out))
    regions_from_fasta = (args.regions_string is None and
                          args.regions_file is None)
    if regions_from_fasta:
        if args.fasta is None:
            logging.critical("must specify either --fasta, "
                             "--regions-string, or --regions-file")
            sys.exit(2)
        try:
            # accept float notation such as '500e3'
            args.task_size = int(math.ceil(float(args.task_size)))
        except (ValueError, OverflowError):
            logging.critical("invalid task size '{}'".format(args.task_size))
            sys.exit(2)
        logging.info("task size = {}".format(args.task_size))
        if args.task_size < 1:
            logging.critical("task size must be > 0")
            sys.exit(2)
    # file checks
    hdf5_out_filename = "{}.h5".format(args.out)
    hdf5_log_filename = hdf5_out_filename + '.log'
    if os.path.lexists(hdf5_out_filename) and not args.force:
        logging.critical("output file '{}' already exists".format(
            hdf5_out_filename))
        sys.exit(2)
    # subprocesses should use same python executable as this script
    python_exec = sys.executable
    # resolve helper scripts before any directory is created or job started
    vcf2npy_path = _resolve_executable(args.vcf2npy_exec)
    vcfnpy2hdf5_path = _resolve_executable(args.vcfnpy2hdf5_exec)

    ## get region_list and chrom_list ##
    region_list = []
    chrom_list = OrderedDict()  # ordered, with quick hashing lookup
    if regions_from_fasta:
        # derive regions from fasta file (or index of it)
        chrom_list = _load_chrom_sizes(args.fasta)
        region_list = _split_into_regions(chrom_list, args.task_size)
    else:
        # regions from regions_string and/or regions_file
        if args.regions_string is not None:
            for region in args.regions_string.split():
                _add_region(region, region_list, chrom_list)
        if args.regions_file is not None:
            logging.info("Load regions from file '{}'".format(
                args.regions_file))
            # plain text mode: the 'U' flag is rejected by Python >= 3.11 and
            # universal newlines are already the text-mode default
            with open(args.regions_file) as instream:
                for lineno, line in enumerate(instream, 1):
                    fields = shlex.split(line, comments=True)
                    if not fields:
                        continue
                    try:
                        region = fields[args.regions_column]
                    except IndexError:
                        logging.critical(
                            "line {} of regions file '{}' has no column "
                            "{}".format(lineno, args.regions_file,
                                        args.regions_column))
                        sys.exit(2)
                    _add_region(region, region_list, chrom_list)

    # check chrom names for the substring '__' and the character ':'
    # @TCC at least '__' could probably be allowed if vcf2npy output file
    # naming was changed (need a separator between chrom and position range)
    for c in chrom_list:
        if '__' in c:
            logging.critical("substring '__' in chrom name '{}' may break this script".format(c))
            sys.exit(2)
        if ':' in c:
            logging.critical("character ':' in chrom name '{}' will probably break this script".format(c))
            sys.exit(2)

    logging.debug("region_list={}".format(region_list))
    logging.debug("chrom_list={}".format(chrom_list))
    logging.info("decomposing into {} tasks".format(len(region_list)))

    # create the cache output directory
    outdir = args.out + '_vcf2hdf5-cache'
    logging.info("creating output directory '{}'".format(outdir))
    try:
        os.makedirs(outdir)
    except OSError as e:
        if e.errno == errno.EEXIST and os.path.isdir(outdir):
            if args.force:
                logging.warning("removeing existing output directory "
                                "'{}'".format(outdir))
                shutil.rmtree(outdir)
                os.makedirs(outdir)
            else:
                logging.critical("output directory '{}' already "
                                 "exists".format(outdir))
                sys.exit(2)
        else:
            raise

    ## parallel vcfnpy ##
    vcf2npy_start_tic = time.time()
    # setup the command for parallel vcf2npy
    cmd = [args.parallel_exec,
           '--gnu',
           '--halt-on-error', '1',
           '--ungroup',
           # NOTE: '--load', '90%' was tried; doesn't seem to work very well
           '--colsep', '\t',
           '--jobs', args.jobs,
           '--nice', args.nice,
           ]
    if logging.getLogger().getEffectiveLevel() <= logging.INFO:
        cmd += ['--eta']
    # the vcf2npy part of the command (a single string run through a shell
    # by parallel); fixed paths are shell-quoted so spaces etc. survive,
    # while {1} (array type) and {2} (region) are filled in by parallel
    subcmd = ""
    subcmd += shlex.quote(python_exec)
    subcmd += " " + shlex.quote(vcf2npy_path)
    subcmd += " --vcf " + shlex.quote(args.vcf)
    subcmd += " --output-dir " + shlex.quote(outdir)
    subcmd += " --array-type {1} "
    subcmd += " --region {2} "
    # '> file 2>&1' is the POSIX spelling of the bash/csh-only '>& file'
    subcmd += " > {} 2>&1".format(
        os.path.join(shlex.quote(outdir), '{1}.{2}.log'))
    cmd += [subcmd]

    logging.info("excuting parallel vcf2npy")
    logging.debug("cmd={}".format(cmd))

    # start the parallel command/process
    proc = subprocess.Popen(cmd, stdin=subprocess.PIPE,
                            universal_newlines=True)

    # send each region from the region_list to parallel process
    # (one job per array type and region, as tab separated columns)
    try:
        for region in region_list:
            proc.stdin.write('variants\t' + region + '\n')
            proc.stdin.write('calldata_2d\t' + region + '\n')
    except IOError as e:
        if e.errno == errno.EPIPE:
            logging.error('Broken pipe')
        elif e.errno == errno.EINVAL:
            logging.error('Invalid argument/pipe')
        else:
            raise
    finally:
        # closing flushes buffered writes, which can hit the same pipe errors
        try:
            proc.stdin.close()
        except IOError as e:
            if e.errno == errno.EPIPE:
                logging.error('Broken pipe')
            elif e.errno == errno.EINVAL:
                logging.error('Invalid argument/pipe')
            else:
                raise
    parallel_status = proc.wait()
    logging.info("finished vcf2npy: {:g} sec elapsed".format(
        time.time() - vcf2npy_start_tic))
    if parallel_status != 0:
        # do not build an HDF5 file from an incomplete cache
        logging.critical("parallel vcf2npy failed with exit status {}; "
                         "see log files in '{}'".format(parallel_status,
                                                        outdir))
        sys.exit(1)

    ## serial vcfnpy2hdf5 ##
    vcfnpy2hdf5_start_tic = time.time()
    logging.info("creating HDF5 file '{}'".format(hdf5_out_filename))
    if os.path.lexists(hdf5_out_filename) and args.force:
        logging.warning("removing existing HDF5 file: "
                        "'{}'".format(hdf5_out_filename))
        os.unlink(hdf5_out_filename)
    # remove existing log file (each job will append)
    try:
        os.unlink(hdf5_log_filename)
    except OSError as e:
        if e.errno != errno.ENOENT:
            raise
    # run vcfnpy2hdf5 for each chrom
    failed_chroms = []
    for chrom in chrom_list:
        logging.info("adding '{}' to HDF5 file".format(chrom))
        cmd = [python_exec,
               vcfnpy2hdf5_path,
               '--vcf', args.vcf,
               '--input-dir', outdir,
               '--input-filename-template',
               '{{array_type}}.{}__*.npy'.format(chrom),
               '--group', chrom,
               '--output', hdf5_out_filename,
               '--compression', args.hdf5_compression,
               '--compression-opts', str(args.hdf5_compression_opts),
               '--chunk-size', str(args.hdf5_chunk_size),
               '--chunk-width', str(args.hdf5_chunk_width),
               ]
        logging.debug("cmd={}".format(cmd))
        with open(hdf5_log_filename, 'a') as logfile:
            print("#" * 79, file=logfile)
            print(cmd, file=logfile)
            # flush our header before the child appends to the same file
            logfile.flush()
            status = subprocess.call(cmd,
                                     stdout=logfile,
                                     stderr=subprocess.STDOUT)
        if status != 0:
            logging.error("vcfnpy2hdf5 failed for '{}' with exit status "
                          "{}".format(chrom, status))
            failed_chroms.append(chrom)
    logging.info("finished vcfnpy2hdf5: {:g} sec elapsed".format(
        time.time() - vcfnpy2hdf5_start_tic))
    if failed_chroms:
        # keep the cache and logs so the failure can be inspected
        logging.critical("vcfnpy2hdf5 failed for: {}; see '{}'".format(
            ', '.join(failed_chroms), hdf5_log_filename))
        sys.exit(1)

    # optional cleanup
    if args.cleanup:
        logging.info("cleaning up... removing cache and logs")
        shutil.rmtree(outdir)
        os.unlink(hdf5_log_filename)

    exit(0)


#########################################################################
# Script entry point: run Main() only when this file is executed directly;
# when imported it just provides the definitions above
if __name__ == "__main__":
    status = Main(argv=None)
    exit(status)

