#!/usr/bin/python -u

"""Computes the MCD metric for two sequences of mel cepstra."""

# Copyright 2014, 2015, 2016, 2017 Matt Shannon

# This file is part of mcd.
# See `License` for details of license and warranty.

import os
import sys
import argparse
import re
import math
import numpy as np

from htk_io.base import DirReader
import htk_io.alignment as alio
import htk_io.vecseq as vsio

from mcd import util
import mcd.metrics_fast as mt

def main(rawArgs):
    """Compute and print the overall MCD for a set of utterances.

    For each utterance id the natural and synthetic parameter sequences are
    read from their respective directories, the 0th cepstral component is
    dropped, and the per-frame cost given by `mt.logSpecDbDist` is summed
    over frames. If a segment-removal regex is given, frames lying in
    segments whose alignment label matches the regex are excluded. The
    printed result is the total cost divided by the total number of frames
    used.

    Args:
        rawArgs: full argument vector with the program name first (as in
            `sys.argv`); the first element is skipped when parsing.

    Raises:
        SystemExit: if the command-line arguments are invalid, including
            when only one of `--remove_segments` and `--alignment_dir` is
            given.
        ValueError: if the natural and synthetic sequences of an utterance
            differ in length, or if the alignment does not expand to the
            same number of frames as the parameter sequences.
    """
    parser = argparse.ArgumentParser(
        description=(
            'Computes the MCD metric for two sequences of mel cepstra.'
            ' Mel cepstral distortion (MCD) is a measure of the difference'
            ' between two sequences of mel cepstra.'
            ' This utility computes the MCD between two sequences of equal'
            ' length that are assumed to already be "aligned" in terms of'
            ' their timing.'
            ' Optionally certain segments (e.g. silence) can be omitted from'
            ' the calculation.'
        ),
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        '--ext', dest='ext', default='mgc', metavar='EXT',
        help=(
            'file extension added to uttId to get file containing speech'
            ' parameters'
        )
    )
    parser.add_argument(
        '--param_order', dest='paramOrder', default=40, type=int,
        metavar='ORDER',
        help='parameter order of the cepstral files'
    )
    parser.add_argument(
        '--remove_segments', dest='removeSegments', default=None,
        metavar='LABELREGEX',
        help='regex of segment labels to remove'
    )
    parser.add_argument(
        '--alignment_dir', dest='alignmentDir', default=None,
        metavar='ALIGNMENTDIR',
        help=(
            'directory containing phone-level alignment files (used for'
            ' segment removal)'
        )
    )
    parser.add_argument(
        '--frame_period', dest='framePeriod', default=0.005, type=float,
        metavar='FRAMEPERIOD',
        help='frame period in seconds (used for segment removal)'
    )
    parser.add_argument(
        dest='natDir', metavar='NATDIR',
        help='directory containing natural speech parameters'
    )
    parser.add_argument(
        dest='synthDir', metavar='SYNTHDIR',
        help='directory containing synthetic speech parameters'
    )
    parser.add_argument(
        dest='uttIds', metavar='UTTID', nargs='+',
        help='utterance ids (ext will be appended to these)'
    )
    args = parser.parse_args(rawArgs[1:])

    # Segment removal needs both the regex and the alignments. This is a
    # user-input check, so report it through argparse rather than an assert
    # (asserts are stripped when Python runs with -O).
    if (args.removeSegments is None) != (args.alignmentDir is None):
        parser.error(
            '--remove_segments and --alignment_dir must be specified'
            ' together'
        )

    if args.removeSegments is not None:
        print('NOTE: removing segments matching regex \'%s\' using alignments'
              ' in %s' % (args.removeSegments, args.alignmentDir))
    reRemoveSegments = (None if args.removeSegments is None
                        else re.compile(args.removeSegments))

    costFn = mt.logSpecDbDist

    alignmentIo = alio.AlignmentIo(args.framePeriod)
    getAlignment = DirReader(alignmentIo, args.alignmentDir, 'lab')

    vecSeqIo = vsio.VecSeqIo(args.paramOrder)
    getNatVecSeq = DirReader(vecSeqIo, args.natDir, args.ext)
    getSynthVecSeq = DirReader(vecSeqIo, args.synthDir, args.ext)

    costTot = 0.0
    framesTot = 0
    for uttId in args.uttIds:
        print('processing %s' % (uttId,))
        nat = getNatVecSeq(uttId)
        synth = getSynthVecSeq(uttId)
        # ignore 0th cepstral component
        nat = nat[:, 1:]
        synth = synth[:, 1:]

        # zip() below would silently truncate to the shorter sequence, so a
        # length mismatch must be rejected explicitly.
        if len(nat) != len(synth):
            raise ValueError(
                'utterance %s: natural (%d frames) and synthetic (%d frames)'
                ' sequences differ in length'
                % (uttId, len(nat), len(synth))
            )

        if reRemoveSegments is None:
            cost = sum(
                costFn(natFrame, synthFrame)
                for natFrame, synthFrame in zip(nat, synth)
            )
            frames = len(nat)
        else:
            alignment = getAlignment(uttId)
            # Replace each segment label by a flag saying whether frames in
            # that segment should be kept.
            alignmentInclude = [
                (startTime, endTime, not reRemoveSegments.search(label))
                for startTime, endTime, label, _ in alignment
            ]
            includeFrames = list(util.expandAlignment(alignmentInclude))
            if len(includeFrames) != len(nat):
                raise ValueError(
                    'utterance %s: alignment covers %d frames but parameter'
                    ' sequence has %d frames'
                    % (uttId, len(includeFrames), len(nat))
                )
            costs = [
                costFn(natFrame, synthFrame)
                for natFrame, synthFrame, includeFrame in zip(nat, synth,
                                                              includeFrames)
                if includeFrame
            ]
            cost = sum(costs)
            frames = len(costs)

        costTot += cost
        framesTot += frames

    # With no frames left (e.g. every segment was removed) the average is
    # undefined; report nan instead of dying with ZeroDivisionError.
    if framesTot == 0:
        sys.stderr.write(
            'WARNING: no frames to compare; overall MCD is undefined\n'
        )
        mcdOverall = float('nan')
    else:
        mcdOverall = costTot / framesTot

    print('overall MCD = %f (%d frames)' % (mcdOverall, framesTot))

# Script entry point: hand main() the full argv; main() itself skips the
# program name (it parses rawArgs[1:]).
if __name__ == '__main__':
    main(sys.argv)
