#!/usr/bin/env python

# Copyright (c) 2011-2016 Timothy Savannah under GPLv3, All Rights Reserved. See LICENSE for more information
"""
Disttask is a utility which provides the ability to distribute a task across a fixed number of processes, for better utilization of multiprocessing.

Use it with existing single-threaded/process tools and scripts to take full advantage of your computer's resources.

"""

import os
import sys
import select
import signal
import subprocess
import threading
import time

from collections import deque

# The version tuple is the single source of truth; the dotted string is derived from it.
__version_tuple__ = (2, 2, 0)

__version__ = '.'.join(str(part) for part in __version_tuple__)

try:
    bytes
except NameError:
    bytes = str # Python < 2.6

if bytes is str:
    # Python 2: str already is the byte-string type, no additional decoding necessary.
    tostr = str
else:
    # Python 3: byte strings must be decoded to obtain a str.
    try:
        defaultEncoding = sys.getdefaultencoding()
    except Exception:
        defaultEncoding = 'utf-8'

    def tostr(x):
        '''
            tostr - Convert @x to a native str.

              str values are returned unchanged, bytes/bytearray values are decoded
              using defaultEncoding, and anything else is passed through str().
        '''
        if isinstance(x, str):
            return x
        if isinstance(x, (bytes, bytearray)):
            return x.decode(defaultEncoding)
        return str(x)


class StdoutWriter(threading.Thread):
    '''
        StdoutWriter - Background thread that serialises queued output chunks onto stdout.

          Other threads call addData() with chunks; run() writes them in FIFO order.
          On Python 3 chunks are written to sys.stdout.buffer, so they must be bytes.
          Set keepGoing = False to make run() return once the queue has been drained.
    '''

    # FLUSH_EVERY - Explicitly flush after this many items.
    FLUSH_EVERY = 1

    def __init__(self, *args, **kwargs):
        threading.Thread.__init__(self, *args, **kwargs)

        # Pending output chunks; appended by producer threads, popped by run().
        self.stdoutData = deque()

        # Cleared by the owner once no more data will be added.
        self.keepGoing = True

    def addData(self, data):
        '''addData - Queue @data to be written to stdout.'''
        self.stdoutData.append(data)

    def setFlushEvery(self, nWrites):
        '''setFlushEvery - Flush stdout after every @nWrites items written.'''
        self.FLUSH_EVERY = nWrites

    def run(self):
        time.sleep(.001) # Yield briefly whilst the caller finishes setup
        stdoutData = self.stdoutData

        try:
            writeOutput = sys.stdout.buffer.write
        except AttributeError:
            # e.g. Python 2: sys.stdout has no binary .buffer and accepts str directly
            writeOutput = sys.stdout.write

        while self.keepGoing is True or len(stdoutData) > 0:
            # Re-read every pass so setFlushEvery() also takes effect after start()
            flushEvery = self.FLUSH_EVERY
            written = 0
            while len(stdoutData) > 0:
                writeOutput(stdoutData.popleft())
                written += 1
                if written >= flushEvery:
                    written = 0
                    sys.stdout.flush()

            sys.stdout.flush()
            time.sleep(.0005)

class Runner(threading.Thread):
    '''
        Runner - Thread that runs one shell command and forwards its combined
          stdout+stderr to @stdoutWriter (any object with an addData() method).

          @param cmd <str> - Command line, executed with shell=True
          @param stdoutWriter - Receives output chunks via addData()
          @param thisItem - The argset item @cmd was built from; used as the
                              "[item] " line prefix when collateOutput is False
          @param collateOutput <bool> - True: buffer all output and queue it as a single
                              chunk after the command exits, so it is not intermixed with
                              other tasks. False: queue each line, prefixed, as it arrives.
    '''

    def __init__(self, cmd, stdoutWriter, thisItem, collateOutput=True):
        threading.Thread.__init__(self)
        self.cmd = cmd
        self.stdoutWriter = stdoutWriter

        self.thisItem = thisItem
        self.collateOutput = collateOutput

        # Cleared to stop reading output early; run() also clears it on an internal error.
        self.keepGoing = True

    def run(self):
        pipe = subprocess.Popen(self.cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        stdoutWriter = self.stdoutWriter
        thisItem = self.thisItem

        if self.collateOutput is True:
            output = []
            handleLine = output.append
        elif sys.version_info[0] >= 3:
            # Pipe data is bytes on Python 3, so the text prefix must be encoded to match.
            # NOTE: relies on the module-level defaultEncoding defined for Python 3.
            def handleLine(line):
                prefix = ('[%s] ' %(thisItem,)).encode(defaultEncoding)
                stdoutWriter.addData(prefix + line)
        else:
            def handleLine(line):
                stdoutWriter.addData('[%s] %s' %(thisItem, line))

        pipeStdout = pipe.stdout
        while self.keepGoing is True and not pipeStdout.closed:
            try:
                (rlist, wlist, errors) = select.select([pipeStdout], [], [pipeStdout], .004)
                if errors:
                    # Error condition reported on the pipe: stop reading this task's output.
                    try:
                        pipeStdout.close()
                    except Exception:
                        pass
                    break

                if not rlist:
                    time.sleep(.002)
                    continue

                line = pipeStdout.readline()
                if line == b'':
                    break # EOF: the command closed its output

                handleLine(line)
            except Exception as e:
                self.keepGoing = False
                try:
                    pipe.terminate()
                except OSError:
                    pass # process already gone
                sys.stderr.write('Got exception: %s\n' %(str(e),))
                break

        # Close our end of the pipe before waiting: if reading stopped early, a child that
        # is still writing gets an error instead of blocking forever on a full pipe.
        # This also releases the descriptor instead of leaving it to garbage collection.
        try:
            pipeStdout.close()
        except Exception:
            pass
        pipe.wait()

        if self.collateOutput is True and output:
            # Join with b'' (bytes on Python 3, str on Python 2) so the whole task's output
            # is queued as ONE item. Queuing chunks one at a time would let other runners'
            # output interleave. Joining with '' would fail for bytes, and for empty output
            # would yield a str that a binary stdout cannot accept.
            try:
                collated = b''.join(output)
            except TypeError:
                # Unexpected chunk types: forward them individually as a last resort
                for item in output:
                    stdoutWriter.addData(item)
            else:
                stdoutWriter.addData(collated)

class DistTask(object):
    '''
        DistTask - Runs @cmd once per argset item, keeping at most @concurrent_tasks
          Runner threads active at a time.

          In @cmd every "%s" is replaced by the item and every "%d" by the slot
          (pipe) number the item is scheduled on.

          @param cmd <str> - Command template handed (after substitution) to Runner
          @param concurrent_tasks <int> - Maximum simultaneous tasks; if 0/falsy,
                                            len(argset) is used instead
          @param argset <iterable> - Initial items
          @param stdoutWriter - Passed to each Runner; its keepGoing is set to False
                                  when run() finishes
          @param endWhenDone <bool> - True: run() returns once the argset is drained
                                       and every task has finished. False: run() also
                                       waits for more items until keepGoing is False.
          @param collateOutput <bool> - Passed through to each Runner
    '''
    def __init__(self, cmd, concurrent_tasks, argset, stdoutWriter, endWhenDone=True, collateOutput=True):
        self.cmd = cmd
        self.concurrent_tasks = concurrent_tasks or len(argset)
        self.argset = deque(argset)
        self.stdoutWriter = stdoutWriter
        self.endWhenDone = endWhenDone
        self.collateOutput = collateOutput

        # keepGoing - Only consulted when endWhenDone is False: while True, run() keeps
        #   waiting for items added via addToArgset/addItemToArgset. Set it to False to
        #   let run() return once the argset is empty and every task has finished.
        self.keepGoing = True

    def addToArgset(self, items):
        '''addToArgset - Append every element of @items to the pending argset.'''
        self.argset += items

    def addItemToArgset(self, item):
        '''addItemToArgset - Append a single @item to the pending argset.'''
        self.argset.append(item)

    def _buildCmd(self, item, slot):
        '''_buildCmd - Return the command line for @item scheduled on @slot.'''
        # Substitute %d first so an item that itself contains "%d" is passed through
        # verbatim instead of having that text replaced by the slot number.
        return self.cmd.replace('%d', str(slot)).replace('%s', item)

    def run(self):
        '''
            run - Schedule argset items onto free slots until there is nothing left
              to do, then set stdoutWriter.keepGoing = False.
        '''
        argset = self.argset
        stdoutWriter = self.stdoutWriter
        collateOutput = self.collateOutput
        numSlots = self.concurrent_tasks
        endWhenDone = self.endWhenDone is True

        # Slot i holds the Runner currently (or most recently) using slot i.
        # Kept local so DistTask also works when imported, not only as a script.
        pipes = [None] * numSlots

        # Starts at -1 so the endWhenDone check passes before the first scheduling pass.
        pipesRunning = -1

        while True:
            if endWhenDone:
                if pipesRunning == 0:
                    break
            elif not (self.keepGoing is True or len(argset) > 0 or pipesRunning > 0):
                break

            pipesRunning = 0
            for slot in range(numSlots):
                current = pipes[slot]
                if current is not None and current.is_alive():
                    pipesRunning += 1
                    continue
                if len(argset) == 0:
                    continue

                if current is not None:
                    current.join() # reap the finished thread before reusing its slot
                nextItem = argset.popleft()
                pipes[slot] = Runner(self._buildCmd(nextItem, slot), stdoutWriter, nextItem, collateOutput)
                pipes[slot].start()
                pipesRunning += 1

            time.sleep(.0002)

        # All tasks are done; reap runners whose slots were never reused.
        for current in pipes:
            if current is not None:
                current.join()

        stdoutWriter.keepGoing = False

if (__name__ == "__main__"):
    args = sys.argv[1:]

    # Either spelling disables collation (output is then streamed line-by-line, prefixed)
    collateOutput = True
    if '-nc' in args:
        args.remove('-nc')
        collateOutput = False
    if '--no-collate' in args:
        args.remove('--no-collate')
        collateOutput = False

    if '--version' in args:
        sys.stderr.write('disttask version %s by Tim Savannah\n' %(__version__,))
        sys.exit(0)

    if len(args) < 3 or '--help' in args:
        sys.stderr.write("Usage: " + os.path.basename(sys.argv[0]) + " [cmd] [concurrent tasks] [argset]\n\n")
        sys.stderr.write("Use a %s in [cmd] where you want the args to go. use %d for the pipe number.\nTo run a list of commands (job server), have '%s' be your full command.\n\n")
        sys.stderr.write('''
    Options:

       -nc or --no-collate          By default, the output will be held until the task is completed, so output is not intermixed.
                                       By providing "-nc" or "--no-collate", instead each line that comes in from any running task
                                       is printed, prefixed with the argset in square-brackets (e.g.  "[arg1] Some message")


   Argsets from stdin

      If argset is '--', then the argset items will be read from stdin instead of being provided on the commandline.
      Execution begins immediately, so you can use disttask as a job manager with another process feeding in items
      as they become available.


   Max Concurrency

      You may use "0" or "MAX" as the "concurrent tasks" parameter to execute all items in the argset simultaneously


   Example Usage

      disttask "ssh root@%s hostname" 3 host1 host2 host3 host4 host5 host6 # Connect and get hostname on 6 hosts, 3 at a time.
''')

        sys.stderr.write("\ndisttask version " + __version__ + "\n")
        sys.exit(1)


    # Module-level list kept for compatibility: some versions of DistTask.run
    # append to a global named "pipes".
    pipes = []


    cmd = args.pop(0)
    if '%s' not in cmd:
        sys.stderr.write("No %s in command!\n")
        sys.exit(127)

    concurrent_tasks = args.pop(0)
    if concurrent_tasks.lower() == 'max':
        concurrent_tasks = 0

    elif concurrent_tasks.isdigit() is False:
        sys.stderr.write('Number of concurrent tasks must be an integer, not "%s"\n' %(concurrent_tasks, ))
        sys.exit(127)

    concurrent_tasks = int(concurrent_tasks)
    argset = args

    # A lone "--" argset means: read items from stdin, one per line
    readFromStdin = (len(argset) == 1 and argset[0] == '--')

    if readFromStdin and concurrent_tasks == 0:
        sys.stderr.write('concurrent tasks = 0 (MAX) is not supported with input from stdin.\n')
        sys.exit(127)

    stdoutWriter = StdoutWriter()
    if collateOutput is False:
        stdoutWriter.setFlushEvery(10)
    stdoutWriter.start()

    if readFromStdin:

        runner = DistTask(cmd, concurrent_tasks, [], stdoutWriter, endWhenDone=False, collateOutput=collateOutput)
        runnerThread = threading.Thread(target=runner.run)
        runnerThread.start()

        while not sys.stdin.closed:
            try:
                nextItem = sys.stdin.readline()
            except (KeyboardInterrupt, Exception):
                # Ctrl-C or any read failure ends input; items already queued
                # are still run to completion below.
                break
            if nextItem == '':
                break # EOF
            # Strip only the line terminator: the final line may lack a newline,
            # so blindly chopping the last character could drop real data.
            runner.addItemToArgset(nextItem.rstrip('\r\n'))

        runner.keepGoing = False
        runnerThread.join()
    else:
        runner = DistTask(cmd, concurrent_tasks, argset, stdoutWriter, endWhenDone=True, collateOutput=collateOutput)
        runner.run()

# vim: set ts=4 sw=4 expandtab
