#!/usr/bin/env python3

import argparse

import tabulate
import fastmap

# Short usage blurb for the CLI.
# NOTE(review): USAGE is not referenced by any code visible in this file;
# confirm nothing else imports it before removing.
USAGE = """Fastmap CLI for running arbitrary Python functions on the cloud.
This CLI only supports the offload workflow and doesn't have any mapping functionality.
Run `fastmap --help` for more details.
"""

# Long-form description with example invocations. Passed to ArgumentParser
# as `description`; the parser uses RawDescriptionHelpFormatter, so the line
# layout below is what appears in `--help` output.
DESCRIPTION = """
Fastmap CLI for offloading scripts.
\n\n
Examples:
$ fastmap poll <task_id>
$ fastmap poll
$ fastmap logs <task_id>
$ fastmap traceback <task_id>
$ fastmap return_value <task_id>
$ fastmap kill <task_id>
$ fastmap clear <task_id>
$ fastmap clear --force

"""

# Footer of the top-level `--help` output (passed to ArgumentParser as `epilog`).
EPILOG = """Run `fastmap <operation> --help` for help on individual operations."""


def _relative_time(seconds):
    if not seconds:
        return 'never'
    if seconds > 60 * 60 * 24 * 2:
        return '%d days ago' % (seconds // (60 * 60 * 24))
    elif seconds > 60 * 60 * 2:
        return '%d hours ago' % (seconds // (60 * 60))
    elif seconds > 120:
        return '%d minutes ago' % (seconds // 60)
    elif seconds > 2:
        return '%d seconds ago' % seconds
    else:
        return 'just now'


def _prettify_tasks(tasks):
    for task in tasks:
        task['type'] = task['type'].lower().capitalize()
        task['start_time'] = task['start_time'].strftime("%Y-%m-%d %H:%M:%S")
        task['runtime'] = task['runtime'] and ("%.1fs" % task['runtime'])
        task['last_heartbeat'] = _relative_time(task['last_heartbeat'])
        # task['state'] = task['task_state']
        # task['id'] = task['task_id']

        # del task['last_heartbeat']
        # del task['task_state']
        # del task['task_id']


        # task['progress'] = task['progress'] and "%.1f%%" % task['progress']


# def offload(config, path, function_name, label):
#     if not os.path.exists(path):
#         raise AssertionError("Path %r does not exist." % path)
#     mod_path = path.replace('/', '.')
#     if mod_path.endswith('.py'):
#         mod_path = mod_path[:-3]
#     sys.path.append(os.getcwd())
#     try:
#         mod = importlib.import_module(mod_path)
#     except ImportError:
#         raise AssertionError("Could not import module %r" % mod_path) from None

#     # TODO this doesn't actually work...
#     try:
#         func = getattr(mod, function_name)
#     except AttributeError:
#         raise AssertionError("Could not import function %r from module %r" %
#                              (func, mod_path)) from None
#     fastmap_task = config.offload(func, label=label)
#     config.log.info("Started new task: %s" % fastmap_task.task_id)


def poll(config, task_id):
    """Print a table of task metadata, most recently started first.

    With a truthy *task_id* only that task is fetched through
    ``config.poll``; otherwise every task returned by ``config.poll_all``
    is listed. The number of tasks found is reported through
    ``config.log.info`` and the table itself goes to stdout.
    """
    tasks = [config.poll(task_id)] if task_id else config.poll_all()

    # Sort on the raw start_time before _prettify_tasks turns it into text.
    tasks.sort(key=lambda task: task['start_time'], reverse=True)
    _prettify_tasks(tasks)

    config.log.info("Found %d task(s)" % len(tasks))
    table = tabulate.tabulate(tasks, headers='keys')
    print(table)


def return_value(config, task_id):
    """Print whatever ``config.return_value`` yields for *task_id*."""
    result = config.return_value(task_id)
    print(result)


def traceback(config, task_id):
    """Fetch the traceback of *task_id* and print it under a log header.

    The traceback is retrieved first, so a failing lookup raises before
    the header is logged through ``config.log.info``.
    """
    trace_text = config.traceback(task_id)
    header = "Traceback for %s:" % task_id
    config.log.info(header)
    print(trace_text)


def kill(config, task_id, force=False):
    """Kill one task, or every PENDING/PROCESSING task when no id is given.

    With a truthy *task_id*, ``config.kill`` is called for that task and
    nothing else happens. Otherwise all tasks are polled, those in a
    ``PENDING`` or ``PROCESSING`` state are selected, and - unless *force*
    is set - the user must answer ``y`` (case-insensitive) to the
    confirmation prompt before they are killed.
    """
    if task_id:
        config.kill(task_id)
        return

    config.log.info("Polling for tasks to kill...")
    active_states = ("PENDING", "PROCESSING")
    victims = [task for task in config.poll_all()
               if task['task_state'] in active_states]
    if not victims:
        config.log.info("Could not find any tasks to kill")
        return

    if not force:
        answer = config.log.input("Kill %d tasks? (y/n) " % len(victims))
        if answer.lower() != 'y':
            return

    for task in victims:
        config.kill(task['task_id'])
    config.log.info("Killed %d tasks" % len(victims))


def logs(config, task_id):
    """Fetch all logs of *task_id* and print them under a log header.

    The logs are retrieved through ``config.all_logs`` first, so a failing
    lookup raises before the header is logged through ``config.log.info``.
    """
    log_text = config.all_logs(task_id)
    header = "Logs for %s:" % task_id
    config.log.info(header)
    print(log_text)


def retry(config, task_id):
    """Ask ``config.retry`` to retry *task_id* and log what it returned.

    The value returned by ``config.retry`` is logged with ``%r`` through
    ``config.log.info``.
    """
    new_task = config.retry(task_id)
    # Wrap the operand in a 1-tuple: a bare tuple on the right of `%` would
    # be unpacked as several format arguments and raise or misformat.
    config.log.info("Retry in process %r" % (new_task,))


def clear(config, task_id, force=False):
    """Clear one task, or all tasks after confirming the DONE count.

    With a truthy *task_id*, ``config.clear`` is called for that task and
    the function returns without printing anything. Otherwise all tasks
    are polled; if none is in the ``DONE`` state the function logs that and
    returns. Unless *force* is set, the user must answer ``y``
    (case-insensitive) to the confirmation prompt. ``config.clear_all`` is
    then invoked and the tasks it returns are printed as a table.
    """
    if task_id:
        # The result was never displayed; the unused local that used to
        # hold it has been dropped.
        config.clear(task_id)
        return

    config.log.info("Polling for tasks to clear...")
    tasks_to_clear = [t for t in config.poll_all() if t['task_state'] == "DONE"]

    if not tasks_to_clear:
        config.log.info("Could not find any tasks to clear")
        return

    if not force:
        prompt = ("Clear all 'DONE' tasks? There are currently %d. (y/n) "
                  % len(tasks_to_clear))
        if config.log.input(prompt).lower() != 'y':
            return

    cleared_tasks = config.clear_all()
    _prettify_tasks(cleared_tasks)
    print(tabulate.tabulate(cleared_tasks, headers='keys'))


if __name__ == "__main__":
    # Script entry point: build the argument parser, initialise fastmap from
    # the chosen configuration, then run the single requested operation.
    parser = argparse.ArgumentParser(
        # usage=DESCRIPTION,
        description=DESCRIPTION,
        epilog=EPILOG,
        # Keep DESCRIPTION's line layout (the example list) intact in --help.
        formatter_class=argparse.RawDescriptionHelpFormatter)

    # Options shared by every operation.
    parser.add_argument(
        "--config",
        help="Location of configuration file generated by deploy_gcp.py. "
             "If omitted, will attempt to use the default configuration. ")
    parser.add_argument(
        "--verbosity",
        choices=("QUIET", "NORMAL", "LOUD"),
        help="How loud fastmap is. Default is NORMAL.",
        default="NORMAL")

    # One sub-command per operation; its name is stored in args.operation
    # and argparse rejects an invocation that omits it.
    subparsers = parser.add_subparsers(
        dest='operation', required=True,
        help='sub-command help')

    # The 'offload' sub-command is currently disabled.
    # offload_p = subparsers.add_parser(
    #     'offload',
    #     help="Offload a function in a python file.")
    # offload_p.add_argument(
    #     "path",
    #     help="The python file. E.g. path/script.py")
    # offload_p.add_argument(
    #     "function",
    #     help="The name of the function in the file. E.g. main_function")
    # offload_p.add_argument(
    #     "label", nargs='?',
    #     help="Optional label for your use")

    poll_p = subparsers.add_parser(
        'poll', help="Get the metadata of one or all tasks")
    poll_p.add_argument(
        "task_id", nargs='?',
        help="Which task to return specifically. If omitted, return all non-CLEARED tasks")

    logs_p = subparsers.add_parser(
        'logs',
        help="Get logs of a task. Task can be in any state except CLEARED. "
             "Logs may be truncated. ")
    logs_p.add_argument(
        "task_id",
        help="Task ID of task to get logs for.")

    return_value_p = subparsers.add_parser(
        'return_value',
        help="Get the return_value of a task in a DONE state.")
    return_value_p.add_argument(
        "task_id",
        help="Task ID")

    traceback_p = subparsers.add_parser(
        'traceback',
        help="Get the traceback of a task in a DONE state with an ERROR outcome.")
    traceback_p.add_argument(
        "task_id",
        help="Task ID")

    kill_p = subparsers.add_parser(
        'kill',
        help="Kill a running task")
    kill_p.add_argument(
        "task_id", nargs='?',
        help="If omitted, kill all tasks")
    kill_p.add_argument(
        '--force', action='store_true',
        help='When task_id is omitted, kill all tasks without confirmation')

    retry_p = subparsers.add_parser(
        'retry',
        help='Retry a task in a DONE state')
    retry_p.add_argument(
        "task_id",
        help="Task to retry")

    clear_p = subparsers.add_parser(
        'clear',
        help="Clear a completed task")
    clear_p.add_argument(
        "task_id", nargs='?',
        help="If omitted, clear all tasks")
    clear_p.add_argument(
        '--force', action='store_true',
        help='When task_id is omitted, clear all tasks without confirmation')

    args = parser.parse_args()

    config = fastmap.init(
        config=args.config,
        verbosity=args.verbosity)

    # Every operation below talks to remote tasks, so a LOCAL policy is
    # rejected up front.
    if config.exec_policy == fastmap.ExecPolicy.LOCAL:
        raise AssertionError("The fastmap CLI does not support a LOCAL exec_policy. "
                             "Check your configuration file.")

    # Exactly one operation is selected per run, so dispatch through lookup
    # tables instead of testing every name in turn.
    # 'offload' would be dispatched as:
    #     offload(config, args.path, args.function, args.label)
    task_only_ops = {
        'poll': poll,
        'return_value': return_value,
        'traceback': traceback,
        'retry': retry,
        'logs': logs,
    }
    # Only these sub-parsers define --force, so only they receive it.
    forceable_ops = {
        'kill': kill,
        'clear': clear,
    }

    if args.operation in forceable_ops:
        forceable_ops[args.operation](config, args.task_id, args.force)
    elif args.operation in task_only_ops:
        task_only_ops[args.operation](config, args.task_id)
