#!/usr/local/bin/python3.11

import os
import pkgutil
import readline
import importlib
import sys
import inspect

readline.parse_and_bind("tab: complete")

def debug(*args, **kwargs):
    """Diagnostic output hook; currently a no-op.

    Takes the same positional and keyword arguments as ``print`` and
    discards them.  To get verbose output while developing, make the body
    forward its arguments to ``print(*args, **kwargs)``.
    """
    return None


class completer(object):
    """Prefix completer over a fixed list of words, for use with readline."""

    def __init__(self, lst):
        # Candidate words, offered in the order given.
        self.lst = lst

    def complete(self, txt, state):
        """Return the ``state``-th candidate starting with ``txt``.

        readline calls this repeatedly with state = 0, 1, 2, ... and stops
        as soon as ``None`` is returned.
        """
        # Lazy generator: candidates past the requested one are not examined.
        hits = (word for word in self.lst if word.startswith(txt))
        for position, word in enumerate(hits):
            if position == state:
                return word
        return None
class boutmodules(object):
    """Discover importable modules that define at least one function.

    Attributes
    ----------
    modules : list of str
        Dotted names of all scanned (sub)modules that contain a function.
    functions : list of str
        Names of the functions found by the last ``scanFunctions`` call.
    funcs : dict
        Maps each name in ``functions`` to the function object.
    cache : sequence
        Values served by ``returnCached``.
    """

    # Immutable class-level fallback: nothing in this class fills the cache,
    # so without it returnCached() would raise AttributeError.
    cache = ()

    def __init__(self,lst):
        """Recursively scan every package/module named in ``lst``."""
        self.modules=[]
        self.functions=[]
        self.funcs={}
        for a in lst:
            self.parseModule(a)
        debug(self.modules)

    def parseModule(self,name):
        """Import ``name``, recurse into its submodules and record it.

        Modules that fail to import are skipped.  ``name`` is appended to
        ``self.modules`` only if the module defines at least one function.
        """
        try:
            mod = importlib.import_module(name)
        except (Exception, SystemExit):
            # Best effort: importing arbitrary modules may fail in many ways
            # (including calling sys.exit()).  KeyboardInterrupt is let
            # through so that the user can abort the scan.
            debug("Failed to import %s"%name)
            return
        pkg_path = getattr(mod, "__path__", None)
        if pkg_path is None:
            # Plain module, not a package: nothing to descend into.
            debug("Failed to scan %s for submodues"%name)
        else:
            # iter_modules only lists the direct children and imports
            # nothing itself; the recursion below does the descending, using
            # the fully qualified name.
            for _finder, module_name, _is_pkg in pkgutil.iter_modules(pkg_path):
                self.parseModule(name+"."+module_name)
        if name not in self.modules and inspect.getmembers(mod, inspect.isfunction):
            self.modules.append(name)

    def returnCached(self,state):
        """Return ``self.cache[state]``, or None if ``state`` is past the end."""
        if len(self.cache) > state:
            return self.cache[state]
        return None

    def scanFunctions(self,module):
        """Collect the functions of ``module`` into ``functions``/``funcs``.

        Raises whatever ``importlib.import_module`` raises if the module
        cannot be imported; the two attributes are empty in that case.
        """
        self.functions=[]
        self.funcs={}
        mod=importlib.import_module(module)
        # getmembers returns (name, value) pairs sorted by name.
        members = inspect.getmembers(mod, inspect.isfunction)
        self.functions = [member_name for member_name, _ in members]
        self.funcs = dict(members)


def esc(astr):
    """Return Python source for a string literal whose value is ``str(astr)``.

    Double quotes are the preferred delimiter.  Single quotes are used when
    the text contains double quotes but no single quotes, so that nothing
    needs escaping in that case.

    Parameters
    ----------
    astr : object
        The value to quote; it is converted with ``str`` first.

    Returns
    -------
    str
        A quoted literal, safe to paste into generated Python code.
    """
    text = str(astr)
    # Backslashes must be doubled before any other escape is introduced.
    # Raw line breaks are escaped too, since they would end the literal.
    body = text.replace("\\", "\\\\").replace("\n", "\\n").replace("\r", "\\r")
    if '"' in text and "'" not in text:
        return "'" + body + "'"
    # Note the doubled backslash: '\"' would be just a plain double quote.
    return '"' + body.replace('"', '\\"') + '"'

def parse_type(typestr):
    """Map a docstring type description to a converter name.

    The description is split at ``,`` and `` or ``.  If the last entry
    contains "optional" it is dropped and the argument is flagged optional.
    The first remaining entry that is a known type decides the result.

    Parameters
    ----------
    typestr : str
        The type part of a parameter line, e.g. ``"int or str, optional"``.

    Returns
    -------
    tuple of (str, bool)
        Name of the callable converting a command line string to the type,
        and whether the parameter is optional.

    Raises
    ------
    ValueError
        If none of the alternatives is a known type.
    """
    knowns = {"str" : "str",
              "int" : "int",
              "slice" : "str_to_slice",
              "bool" : "str_to_bool"}
    # Strip each alternative: "list, int" splits into "list" and " int", and
    # the leading blank would otherwise prevent the lookup from matching.
    candidates = []
    for part in typestr.split(','):
        for alternative in part.split(' or '):
            candidates.append(alternative.strip())
    opt = "optional" in candidates[-1]
    if opt:
        candidates = candidates[:-1]
    for candidate in candidates:
        if candidate in knowns:
            return knowns[candidate], opt
    raise ValueError("unknown type '%s'"%typestr)


def _parse_docstring(docstring):
    """Extract summary, parameter types and help texts from a docstring.

    Everything before a line reading "Parameters" is joined into the
    summary.  Within the parameter section, lines at the indentation of the
    "Parameters" title are ``name : type`` entries and deeper indented lines
    are help text for the preceding entry.

    Returns
    -------
    tuple of (str, dict, dict)
        The summary, a map from parameter name to converter name (see
        ``parse_type``) and a map from parameter name to a list of already
        quoted help string literals.

    Raises
    ------
    ValueError
        If a parameter line cannot be parsed or its type is unknown.
    """
    # Titles of sections that may follow the parameters.  Only "Return" was
    # recognised before; any other title used to crash with IndexError.
    section_titles = ("Return", "Returns", "Yields", "Raises", "Notes",
                      "Examples", "See Also", "References", "Warns",
                      "Warnings")
    summary = []
    arg_type = {}
    arg_help = {}
    in_params = False
    indent = 0
    current = None
    for line in docstring.splitlines():
        text = line.strip()
        if not in_params:
            if text == "Parameters":
                in_params = True
                # Indentation of the title marks parameter entries.
                indent = len(line) - len(line.lstrip(" "))
            elif text:
                summary.append(text)
            continue
        if text == "Return":
            break
        if not text or text in ("Parameters", "----------"):
            continue
        if len(line) > indent and line[indent] != " ":
            if text in section_titles:
                break
            parts = text.split(" : ")
            if len(parts) < 2:
                raise ValueError(
                    "Cannot parse parameter line '%s', expected 'name : type'"
                    % text)
            current = parts[0]
            arg_type[current] = parse_type(parts[1])[0]
        elif current is None:
            raise ValueError(
                "Help text '%s' does not belong to any parameter" % text)
        else:
            help_lines = arg_help.setdefault(current, [])
            # The literals are emitted next to each other and concatenated
            # by Python, so all but the first start with a blank.
            help_lines.append(esc(" " + text if help_lines else text))
    return " ".join(summary), arg_type, arg_help


def createwrapper(mod,func_name,func,name):
    """Write an executable argparse wrapper script for a function.

    Parameters
    ----------
    mod : str
        the module name
    func_name : str
        the function name
    func : function
        The function itself. Its docstring must describe every parameter in
        a "Parameters" section with ``name : type`` lines.
    name : str
        The filename

    Raises
    ------
    KeyError
        If a parameter of ``func`` is not described in its docstring.
    ValueError
        If the docstring cannot be parsed or uses an unknown type.
    """
    # A missing docstring is treated like an empty one.
    docs, arg_type, arg_help = _parse_docstring(func.__doc__ or "")

    # The script is assembled in memory and written in one go at the end, so
    # that a failure does not leave a truncated file behind.
    chunks = []

    def fprint(*args,end="\n"):
        chunks.append(" ".join(args) + end)

    fprint("""#!/usr/bin/env python3
# PYTHON_ARGCOMPLETE_OK

import argparse
try:
    import argcomplete
except ImportError:
    argcomplete=None

""")
    # Print functions that are needed
    converters = set(arg_type.values())
    if "str_to_slice" in converters:
        fprint("""
def str_to_slice(sstr):
    args=[]
    for s in sstr.split(','):
        args.append(int(s))
    print(args)
    return slice(*args)
""")
    if "str_to_bool" in converters:
        # The generated script only does "import argparse", so the exception
        # has to be qualified.
        fprint("""
def str_to_bool(sstr):
    low=sstr.lower()
    # no or false
    if low.startswith("n") or low.startswith("f"):
        return False
    # yes or true
    elif low.startswith("y") or low.startswith("t"):
        return True
    else:
        raise argparse.ArgumentTypeError("Cannot parse %s to bool type"%sstr)
""")
    # Create the parser. The first positional argument of ArgumentParser is
    # prog, hence the keyword.
    fprint("parser = argparse.ArgumentParser(description=%s)"%(esc(docs)))
    for pname, param in inspect.signature(func).parameters.items():
        ptype = arg_type[pname]
        if param.default is inspect.Parameter.empty:
            # No default value present. Make positional argument
            fprint('parser.add_argument("%s", type=%s'%(pname,ptype),end="")
        elif ptype == "str_to_bool" and param.default == False:
            # "==" on purpose: any default equal to False, e.g. 0, counts.
            fprint('parser.add_argument("--%s", action="store_true", default=False'%(pname),end="")
        else:
            fprint('parser.add_argument("--%s", type=%s'%(pname,ptype),end="")
            fprint(", default=",end="")
            if isinstance(param.default, str):
                fprint(esc(param.default),end="")
            else:
                fprint(str(param.default),end="")
            if ptype == "str_to_slice":
                arg_help.setdefault(pname, []).append(
                    esc("Pass slice object as comma separated list without spaces."))
        if pname in arg_help:
            pre="\n    ,help="
            for help_literal in arg_help[pname]:
                fprint(pre,help_literal,end="")
                pre="\n    "
        fprint(")")

    fprint("""
if argcomplete:
    argcomplete.autocomplete(parser)

# late import for faster auto-complete""")
    fprint("from %s import %s"%(mod,func_name))
    fprint("""
args = parser.parse_args()

# Call the function %s, using command line arguments
%s(**args.__dict__)

"""%(func_name,func_name))

    with open(name,"w") as f:
        f.write("".join(chunks))
    # rwxr-xr-x: the wrapper is meant to be run directly.
    os.chmod(name, 0o755)

def main():
    """Interactive / command line driver.

    Usage: ``<script> [module function [filename [-f]]]``

    Without module and function, both are asked for interactively with tab
    completion.  The file name is confirmed interactively unless ``-f`` is
    given as fourth argument.
    """
    realfun=None
    if len(sys.argv) < 3:
        print("Please wait, scanning modules ...")
        x=boutmodules(["boutcore",
                       "boutdata",
                       "bout_runners",
                       "boututils",
                       "post_bout",
                       "zoidberg"])
        readline.set_completer(completer(x.modules).complete)
        mod=input("Module: ")
        x.scanFunctions(mod)
        readline.set_completer(completer(x.functions).complete)
        fun=input("Function: ")
        realfun=x.funcs[fun]
    else:
        mod=sys.argv[1]
        fun=sys.argv[2]
    if len(sys.argv) < 4:
        name="bin/bout-%s-%s"%(mod,fun)
        name=name.replace(".","-")
        readline.set_completer()
    else:
        name=sys.argv[3]
    force = len(sys.argv) > 4 and sys.argv[4] == "-f"
    while not force:
        # Pre-fill the prompt with the current suggestion, so that the user
        # can simply confirm or edit it.
        readline.set_startup_hook(lambda: readline.insert_text(name))
        name=input("Filename: ")
        readline.set_startup_hook()
        if not os.path.exists(name):
            break
        print("File %s already exists. If you continue it will be overwritten."%name)
        doit=input("Do you want to continue? [y/N] ")
        # Only an answer starting with "y" confirms; a substring test would
        # also accept e.g. "no way".
        if doit.strip().lower().startswith("y"):
            break
    if realfun is None:
        # Look the function up directly rather than exec()-ing an import
        # statement built from command line arguments.
        realfun = getattr(importlib.import_module(mod), fun)
    try:
        createwrapper(mod,fun,realfun,name)
    except:
        # Catch-all is fine here: the error is re-raised after the hint.
        print("Creating failed. To rerun and overwrite the file without asking run:")
        print("%s %s %s %s -f"%(sys.argv[0],mod,fun,name))
        raise


if __name__ == "__main__":
    main()
