#!/usr/bin/env python
from __future__ import print_function, division, absolute_import
import os
import sys
import argparse
from glob import glob
from datetime import datetime
import operator


import numpy as np
import h5py


import vcfnp
import vcfnp.vcflib as vcflib
from vcfnp.compat import b as _b


# reduce() is a builtin on Python 2 only; on Python 3 it has to be imported
# from functools before the chunk size calculations below can use it
PY2 = sys.version_info < (3,)
if not PY2:
    from functools import reduce


def log(*msg):
    """Write a timestamped, prefixed message to standard error.

    All positional arguments are converted with ``str`` and joined with
    single spaces; the stream is flushed after every message.
    """
    timestamp = str(datetime.now())
    body = ' '.join(str(part) for part in msg)
    print('[vcfnpy2hdf5] %s :: %s' % (timestamp, body), file=sys.stderr)
    sys.stderr.flush()


def load_compound_arrays_into_dataset(input_filenames, dataset_name,
                                      parent_group, chunk_size, chunk_width,
                                      compression, compression_opts, shuffle,
                                      fletcher32, scaleoffset):
    """Concatenate several ``.npy`` arrays into one resizable dataset.

    The first file is loaded to obtain the dtype and the trailing dimensions
    used to create the dataset. Every file is then loaded in the order given
    and appended along the first axis. Any existing member of `parent_group`
    called `dataset_name` is deleted first.

    Parameters
    ----------
    input_filenames : sequence of str
        Paths of the ``.npy`` files to load. ``input_filenames[0]`` is read
        unconditionally, so the sequence must not be empty.
    dataset_name : str
        Name of the dataset to (re)create inside `parent_group`.
    parent_group : group-like
        Object providing ``in``, ``del`` and ``create_dataset`` (an h5py
        group when called from ``load_hdf5``).
    chunk_size : int
        Target size of one chunk in bytes.
    chunk_width : int
        Maximum chunk extent along the second axis for arrays with more than
        one dimension.
    compression, compression_opts, shuffle, fletcher32, scaleoffset
        Passed through unchanged to ``create_dataset``.

    Returns
    -------
    ds
        The dataset returned by ``parent_group.create_dataset``.

    Raises
    ------
    Exception
        If the first array has no dimensions.
    """

    # the first file serves as the template for dtype and trailing dimensions
    a = np.load(input_filenames[0])

    if dataset_name in parent_group:
        log(dataset_name, 'delete existing dataset')
        del parent_group[dataset_name]

    log(dataset_name, 'setup dataset')

    # determine shape
    shape = (0,) + a.shape[1:]  # initially empty
    maxshape = (None,) + a.shape[1:]  # resizable along first dimension

    # determine chunk shape
    itemsize = a.dtype.itemsize
    if a.ndim == 1:
        # an item larger than chunk_size would give a chunk length of 0,
        # which is not a valid chunk shape, so always keep at least one item
        chunks = (max(1, chunk_size // itemsize),)
    elif a.ndim > 1:
        chunk_width = min(chunk_width, a.shape[1])
        row_size = reduce(
            operator.mul,
            (itemsize, chunk_width) + a.shape[2:]
        )
        # same clamp as above: a row larger than chunk_size still needs a
        # chunk that is at least one row high
        chunk_height = max(1, chunk_size // row_size)
        chunks = (chunk_height, chunk_width) + a.shape[2:]
    else:
        # report the dimension count rather than dumping the array itself
        raise Exception('unexpected number of dimensions: %r' % (a.ndim,))

    log(
        dataset_name, 'creating dataset',
        'dtype', a.dtype,
        'shape', shape,
        'maxshape', maxshape,
        'chunks', chunks,
        'chunk size', itemsize * reduce(operator.mul, chunks),
        'compression', compression,
        'compression_opts', compression_opts,
        'shuffle', shuffle,
        'fletcher32', fletcher32,
        'scaleoffset', scaleoffset
    )
    ds = parent_group.create_dataset(dataset_name, shape=shape, dtype=a.dtype,
                                     chunks=chunks, maxshape=maxshape,
                                     compression=compression,
                                     compression_opts=compression_opts,
                                     shuffle=shuffle, fletcher32=fletcher32,
                                     scaleoffset=scaleoffset)

    log(dataset_name, 'load data')
    n = 0  # number of rows written so far
    for fn in input_filenames:
        a = np.load(fn)
        if a.size > 0:  # empty arrays contribute nothing
            n_add = a.shape[0]
            n_new = n + n_add
            log(dataset_name, 'loading', n, n_new)
            # grow the dataset, then fill the newly added rows
            ds.resize(n_new, axis=0)
            ds[n:n_new] = a
            n = n_new

    return ds


def load_compound_arrays_into_group(input_filenames, group_name, parent_group,
                                    chunk_size, chunk_width, compression,
                                    compression_opts, shuffle, fletcher32,
                                    scaleoffset):
    """Load ``.npy`` arrays with named fields into a group of datasets.

    The first file is loaded to discover the field names (``dtype.names``)
    and, per field, the dtype and trailing dimensions. One resizable dataset
    per field is created inside the group `group_name`, replacing any
    existing member with the same name. Every file is then loaded in the
    order given and each of its fields is appended to the matching dataset
    along the first axis.

    Parameters
    ----------
    input_filenames : sequence of str
        Paths of the ``.npy`` files to load. ``input_filenames[0]`` is read
        unconditionally, so the sequence must not be empty.
    group_name : str
        Name of the group, obtained with ``parent_group.require_group``.
    parent_group : group-like
        Object providing ``require_group`` (an h5py group when called from
        ``load_hdf5``).
    chunk_size : int
        Target size of one chunk in bytes.
    chunk_width : int
        Maximum chunk extent along the second axis for fields with more than
        one dimension.
    compression, compression_opts, shuffle, fletcher32, scaleoffset
        Passed through unchanged to ``create_dataset``.

    Returns
    -------
    grp
        The group holding one dataset per field.

    Raises
    ------
    Exception
        If a field of the first array has no dimensions.
    """

    # the first file serves as the template for field names and layouts
    a = np.load(input_filenames[0])
    grp = parent_group.require_group(group_name)

    log(group_name, 'setup datasets')
    for f in a.dtype.names:

        if f in grp:
            log(group_name, f, 'deleting existing dataset')
            del grp[f]

        col = a[f]
        itemsize = col.dtype.itemsize

        # determine shape
        shape = (0,) + col.shape[1:]  # initially empty
        maxshape = (None,) + col.shape[1:]  # resizable along first dimension

        # determine chunk shape
        if col.ndim == 1:
            # an item larger than chunk_size would give a chunk length of 0,
            # which is not a valid chunk shape, so always keep at least one
            chunks = (max(1, chunk_size // itemsize),)
        elif col.ndim > 1:
            # keep the width local to this field: overwriting chunk_width
            # here would let one narrow field shrink the chunks of every
            # wider field processed after it
            width = min(chunk_width, col.shape[1])
            row_size = reduce(
                operator.mul,
                (itemsize, width) + col.shape[2:]
            )
            # same clamp as above for rows larger than chunk_size
            chunk_height = max(1, chunk_size // row_size)
            chunks = (chunk_height, width) + col.shape[2:]
        else:
            # report the dimension count rather than dumping the array itself
            raise Exception(
                'unexpected number of dimensions: %r' % (col.ndim,)
            )

        log(
            group_name, f, 'creating dataset',
            'dtype', col.dtype,
            'shape', shape,
            'maxshape', maxshape,
            'chunks', chunks,
            'chunk size', itemsize * reduce(operator.mul, chunks),
            'compression', compression,
            'compression_opts', compression_opts,
            'shuffle', shuffle,
            'fletcher32', fletcher32,
            'scaleoffset', scaleoffset
        )
        grp.create_dataset(f, shape=shape, dtype=col.dtype, chunks=chunks,
                           maxshape=maxshape, compression=compression,
                           compression_opts=compression_opts, shuffle=shuffle,
                           fletcher32=fletcher32, scaleoffset=scaleoffset)

    log(group_name, 'load data')
    n = 0  # number of rows written so far
    for fn in input_filenames:
        a = np.load(fn)
        if a.size > 0:  # empty arrays contribute nothing
            n_add = a.shape[0]
            n_new = n + n_add
            log(group_name, 'loading', n, n_new)
            for f in a.dtype.names:
                # grow the field's dataset, then fill the newly added rows
                grp[f].resize(n_new, axis=0)
                grp[f][n:n_new, ...] = a[f]
            n = n_new

    return grp


def load_hdf5(input_dir, output_filename,
              input_filename_template='{array_type}*.npy', vcf_filename=None,
              variants_only=False, calldata_only=False, tabulate_variants=False,
              parent_group_name='/', chunk_size=2**20, chunk_width=10,
              compression=None, compression_opts=None, shuffle=False,
              fletcher32=False, scaleoffset=None):
    """Load the ``.npy`` files found in `input_dir` into an HDF5 file.

    The output file is opened in append mode. Input files are located by
    formatting `input_filename_template` with ``array_type='variants'`` and
    ``array_type='calldata_2d'``, globbing inside `input_dir` and sorting the
    matches. Variants are stored either as a single dataset
    (`tabulate_variants`) or as a group with one dataset per field; calldata
    is always stored as a group called ``'calldata'``.

    Parameters
    ----------
    input_dir : str
        Existing directory containing the input arrays.
    output_filename : str
        Path of the HDF5 file to create or update.
    input_filename_template : str
        Glob pattern with an ``{array_type}`` placeholder.
    vcf_filename : str, optional
        If given, the file must exist; its ``sample_names`` are stored in a
        ``'samples'`` dataset, replacing any existing one.
    variants_only, calldata_only : bool
        Skip loading the calldata or the variants, respectively. When neither
        is set, the number of variants and calldata files must be equal.
    tabulate_variants : bool
        Store the variants as one dataset instead of a group.
    parent_group_name : str
        Group inside the HDF5 file that receives everything.
    chunk_size, chunk_width, compression, compression_opts, shuffle,
    fletcher32, scaleoffset
        Storage options handed to the array loading functions.
    """

    # guard conditions
    assert input_dir is not None and os.path.exists(input_dir)
    assert vcf_filename is None or os.path.exists(vcf_filename)

    # storage options shared by every loader call below
    storage_options = dict(
        chunk_size=chunk_size,
        chunk_width=chunk_width,
        compression=compression,
        compression_opts=compression_opts,
        shuffle=shuffle,
        fletcher32=fletcher32,
        scaleoffset=scaleoffset
    )

    # template for input arrays
    input_path_template = os.path.join(input_dir, input_filename_template)

    def find_input_files(array_type):
        # sorted so that the arrays are appended in a reproducible order
        return sorted(glob(input_path_template.format(array_type=array_type)))

    log('open hdf5 file', output_filename)
    with h5py.File(output_filename, 'a') as h5f:

        log('setup parent group', parent_group_name)
        parent_group = h5f.require_group(parent_group_name)

        if vcf_filename is None:
            log('no VCF file provided, skipping metadata storage')
        else:
            log('store metadata from VCF', vcf_filename)
            vcf = vcflib.PyVariantCallFile(vcf_filename)
            if 'samples' in parent_group:
                del parent_group['samples']
            samples = np.array(_b(vcf.sample_names))
            parent_group.create_dataset('samples', data=samples)

        # find input files
        variants_input_filenames = find_input_files('variants')
        calldata_input_filenames = find_input_files('calldata_2d')
        nvf = len(variants_input_filenames)
        nc2df = len(calldata_input_filenames)
        log('variants found', nvf, 'array files')
        log('calldata found', nc2df, 'array files')

        # the file counts only have to agree when both kinds are loaded
        if not calldata_only and not variants_only:
            assert nvf == nc2df, "number of files doesn't match"

        if not calldata_only:

            log('process variants')

            if tabulate_variants:
                log('variants use dataset layout')
                load_variants = load_compound_arrays_into_dataset
            else:
                log('variants use group layout')
                load_variants = load_compound_arrays_into_group
            load_variants(variants_input_filenames, 'variants', parent_group,
                          **storage_options)

        if not variants_only:

            log('process calldata')

            load_compound_arrays_into_group(calldata_input_filenames,
                                            'calldata', parent_group,
                                            **storage_options)

        log('all done')


def main():
    """Parse the command line and run ``load_hdf5`` with the result.

    ``--compression-opts`` is only forwarded when the selected compression is
    ``gzip``; for any other compression ``None`` is passed instead.
    """

    # handle command line args
    python_version = '.'.join(str(part) for part in sys.version_info[:3])
    epilog = "Version: vcfnp %s (Python %s, NumPy %s, h5py %s)" % (
        vcfnp.__version__, python_version, np.__version__, h5py.__version__
    )
    parser = argparse.ArgumentParser(epilog=epilog)
    add_option = parser.add_argument

    # input and output locations
    add_option('--vcf', dest='vcf_filename', metavar='VCF', default=None,
               help='VCF file to extract metadata from')
    add_option('--input-dir', dest='input_dir', metavar='DIR', default=None,
               help='input directory containing npy files')
    add_option('--input-filename-template', dest='input_filename_template',
               metavar='TEMPLATE', default='{array_type}*.npy',
               help='template for input file names, '
                    'defaults to "{array_type}*.npy"')
    add_option('--output', dest='output_filename', metavar='HDF5',
               default=None,
               help='name of output HDF5 file')
    add_option('--group', dest='group', metavar='GROUP', default='/',
               help='destination group in HDF5 file, defaults to root group')

    # chunking
    add_option('--chunk-size', dest='chunk_size', metavar='NBYTES', type=int,
               default=2**20,
               help='chunk size (defaults to 1Mb)')
    add_option('--chunk-width', dest='chunk_width', metavar='N', type=int,
               default=10,
               help='chunk width for 2 dimensional datasets '
                    '(defaults to 10)')

    # filters
    add_option('--compression', dest='compression', metavar='NAME',
               default='gzip',
               help='compression, default is gzip')
    add_option('--compression-opts', dest='compression_opts',
               metavar='LEVEL', type=int, default=1,
               help='compression level, applies only to gzip '
                    '(defaults to 1)')
    add_option('--shuffle', dest='shuffle', action='store_true',
               default=False,
               help='use the shuffle filter (defaults to False)')
    add_option('--fletcher32', dest='fletcher32', action='store_true',
               default=False,
               help='use the Fletcher32 checksum filter '
                    '(defaults to False)')
    add_option('--scaleoffset', dest='scaleoffset', metavar='N', type=int,
               default=None,
               help='use the scale-offset filter (defaults to None)')

    # what to load and how to lay it out
    add_option('--variants-only', dest='variants_only', action='store_true',
               default=False,
               help="load variants only, don't look for calldata")
    add_option('--calldata-only', dest='calldata_only', action='store_true',
               default=False,
               help="load calldata only, don't look for variants")
    add_option('--tabulate-variants', dest='tabulate_variants',
               action='store_true', default=False,
               help="organise the variants as a dataset with a compound "
                    "dtype (defaults to False - organise variants as a "
                    "group, with each column as a separate dataset)")

    args = parser.parse_args()

    # the compression level is only meaningful for gzip
    compression = args.compression
    if compression == 'gzip':
        compression_opts = args.compression_opts
    else:
        compression_opts = None

    load_hdf5(
        input_dir=args.input_dir,
        output_filename=args.output_filename,
        input_filename_template=args.input_filename_template,
        vcf_filename=args.vcf_filename,
        variants_only=args.variants_only,
        calldata_only=args.calldata_only,
        tabulate_variants=args.tabulate_variants,
        parent_group_name=args.group,
        chunk_size=args.chunk_size,
        chunk_width=args.chunk_width,
        compression=compression,
        compression_opts=compression_opts,
        shuffle=args.shuffle,
        fletcher32=args.fletcher32,
        scaleoffset=args.scaleoffset
    )


# run the command line interface only when executed as a script, not on import
if __name__ == '__main__':
    main()
