import gzip
import itertools
import os
import re
import shlex
import shutil
import subprocess
import sys

# Load the workflow settings; the values are read through `config` below.
configfile: "config.json"
output_dir = config["output"]
# Target BAM of the subsample_map rule; left empty when no mapped-read name is configured.
output_bam = (
    "intermediate_data/" + config["mapped_name"] + "." + str(config["subsample"]) + ".bam"
    if config["mapped_name"]
    else ""
)

# Produce the BAM used downstream: either a plain copy of config["mapped"]
# (subsample == "full", or fewer mapped reads than requested) or a random
# subsample of roughly config["subsample"] alignments drawn with samtools.
rule subsample_map:
    input:
#        config["mapped"],
        'programs.checked'
    priority: 1
    output:
        output_bam
    run:
        mapped = config["mapped"]
        # Quote the paths so that spaces or shell metacharacters cannot break the commands.
        src = shlex.quote(mapped)
        dst = shlex.quote(output_bam)
        if config["subsample"] == "full":
            shell("cp " + src + " " + dst)
        else:
            # Count the alignments that are neither unmapped nor secondary (-F 260).
            proc = subprocess.Popen(
                ["samtools", "view", "-c", "-F", "260", mapped],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
            out, err = proc.communicate()
            if proc.returncode != 0:
                # Stop here: without a read count there is nothing sensible to subsample.
                raise RuntimeError(
                    "samtools view -c failed (exitcode: " + str(proc.returncode)
                    + ")\nerror message: " + err.decode("utf-8", "replace")
                )

            nb_reads = int(out.decode("utf-8").split("\n")[0])
            wanted = int(config["subsample"])
            if nb_reads > 0:
                propotion = (wanted * 100) / nb_reads
                print(str(config["subsample"]) + " reads correspond to " + str(propotion) + " % of mapped reads")

            if nb_reads == 0 or wanted >= nb_reads:
                # Nothing to thin out: keep every alignment.
                shell("cp " + src + " " + dst)
            else:
                # `samtools view -s` takes SEED.FRACTION: the integer part is the
                # random seed and only the decimals are the fraction to keep, so a
                # percentage must not be passed. Fixed-point formatting avoids
                # scientific notation, which samtools would misread.
                fraction = "{:.10f}".format(wanted / nb_reads)
                shell("samtools view -s " + fraction + " -b " + src + " > " + dst)

def get_reads():
    """Return the configured read files, or a placeholder string when none are set."""
    return config["reads"] or "no_reads_provided"

# Prepare the FASTQ input for the analysis: the helpers defined in the run
# block subsample (or copy) the configured read files and, where needed,
# deinterleave them and rewrite their headers.
rule subsample_fastq:
    input:
        #get_reads,
        # Every path listed in config["reads"], passed through unchanged.
        expand("{read}", read=config["reads"]),
        'programs.checked'
    priority: 1
    output:
        # One gzipped FASTQ per entry of config["reads_names"], tagged with the subsample size.
        expand("intermediate_data/{read}."+str(config["subsample"])+".fastq.gz", read=config["reads_names"])
    run:
        # Requested subsample size; subsample() compares it against "full".
        no_subsample = config["subsample"]
        def check_existence(file):
            """
            Check that `file` exists and can be opened for reading.

            Prints a hint and raises SystemExit with a non-zero status when the
            file cannot be opened, so the failure is not reported as a
            successful run.
            """
            try:
                with open(file,"r"):
                    pass
            except IOError:
                print("Could not find file '"+file+"'. Make sure it exists.")
                sys.exit(1)

        def change_header_format(read,num):
            """
            Rewrite the read headers of a fastq file as 'readID/num' (num = 1 or 2).

            The rewrite rule is chosen from the first header of the file:
              * '@x.1.1 ...'      -> dots become colons, ':<digit> <rest>' becomes '/num'
              * '@x ...' with '.' -> dots become colons, ' <rest>' becomes '/num'
              * anything else     -> ' <rest>' becomes '/num'

            Only header lines (every fourth line, starting with '@') are touched.
            Quality lines may also start with '@' and must not be rewritten.

            The result is written gzipped to a temporary file and then moved to
            `read` (when it already ends in '.gz') or to `read` + '.gz'.

            Returns the path of the rewritten file.
            """
            num=str(num)
            tmp=output_dir+"intermediate_data/tmp.gz"
            is_gzipped=read.endswith(".gz")
            fixed_header=read if is_gzipped else read+".gz"
            opener=gzip.open if is_gzipped else open

            fline=get_first_line(read)
            if re.match(r"^@.+\.\d+\.\d ",fline):
                dots_to_colons=True
                tail=re.compile(r":[0-9] .*")
            elif re.match(r"^@.+ ",fline) and "." in fline:
                dots_to_colons=True
                tail=re.compile(r" .*")
            else:
                dots_to_colons=False
                tail=re.compile(r" .*")

            suffix="/"+num
            with opener(read,"rt") as src, gzip.open(tmp,"wt") as dst:
                for index,line in enumerate(src):
                    # A fastq record is four lines; the header is the first of them.
                    if index%4==0 and line.startswith("@"):
                        if dots_to_colons:
                            line=line.replace(".",":")
                        # '.' does not match the newline, so the line ending is kept.
                        line=tail.sub(suffix,line,count=1)
                    dst.write(line)

            shutil.move(tmp,fixed_header)
            return fixed_header

        def get_first_line(read):
            """
            Return the first line of `read` as text, decompressing on the fly
            when the file name ends in '.gz'.
            """
            if read.endswith(".gz"):
                opener,mode=gzip.open,"rt"
            else:
                opener,mode=open,"r"
            with opener(read,mode) as handle:
                return handle.readline()

        def check_readfile_extension(read):
            """
            Check that the file name of `read` contains '.fq' or '.fastq'.

            Only the file name is inspected, so a directory called e.g.
            'run.fq/' cannot make an unrelated file pass the check.
            Prints an error and raises SystemExit with a non-zero status when
            neither extension is found.
            """
            filename=os.path.basename(read)
            if not any(ext in filename for ext in (".fq",".fastq")):
                print("Error. Cannot find .fq or .fastq extension in read file --read "+read+".")
                sys.exit(1)


        def subsample(read,num):
            """
            Copy the first `num` reads of a fastq file to a gzipped file under
            intermediate_data/ and return that file for the rest of the analysis.

            Working on the copy keeps the run fast and protects the original
            read file from being modified.

            `num` is either the string "full" (copy everything, compressing
            if the input is not gzipped) or a read count given as int or
            numeric string.

            Returns (header_flag, subsampled). header_flag is False when an
            existing subsampled file was found and reused, True when the file
            was created by this call.
            """
            # NOTE(review): the dots in this pattern are unescaped and match any
            # character. Left unchanged because the declared rule outputs
            # (config["reads_names"]) may be derived with the same pattern;
            # confirm before tightening it to r"\.fq|\.fastq".
            readname=re.split(r".fq|.fastq",os.path.basename(read))[0]
            subsampled=output_dir+"intermediate_data/"+readname+"."+str(num)+".fastq.gz"
            if os.path.exists(subsampled):
                print("Already found file with "+str(num)+" subsampled reads.")
                print("Subsampled file "+subsampled+" will be used for analysis.")
                return False, subsampled

            compressed=read.endswith((".gz",".gzip"))
            if num == "full":
                if compressed:
                    # Already gzipped: copy as it is.
                    shutil.copy2(read, subsampled)
                else:
                    # Copy and compress on the fly.
                    with open(read, 'rb') as f_in:
                        try:
                            with gzip.open(subsampled, 'wb') as f_out:
                                f_out.writelines(f_in)
                        except IOError:
                            print("Could not write file '"+subsampled+"'. Make sure it exists.")
                            sys.exit(1)
            else:
                print("Subsampling "+str(num)+" reads of "+read+" to new read file: "+subsampled)
                # A read consists of 4 lines in a fastq file. int() matters here:
                # multiplying a numeric *string* by 4 would repeat its digits.
                line_count=4*int(num)
                opener=gzip.open if compressed else open
                # Pure-Python copy: no shell quoting issues and no broken-pipe
                # warning from cutting a decompression pipeline short.
                with opener(read,"rb") as f_in, gzip.open(subsampled,"wb") as f_out:
                    f_out.writelines(itertools.islice(f_in,line_count))
            return True, subsampled


        def check_single_reads(read):
            """
            Validate a single read file and prepare it for the analysis.

            The file is checked for existence and subsampled. If the subsampled
            file looks interleaved it is split into a left and a right file.
            Headers containing whitespace, underscores or dots are rewritten as
            'readID/pair#' (pair# = 1 or 2 depending on the mate), since pysam
            and Trinity cannot handle the other formats.

            Returns a list with one path (not interleaved) or two paths
            (deinterleaved left and right reads).
            """

            # --- Inner functions ---

            def is_interleaved(read,pattern1,pattern2):
                """
                Check whether a single read file is interleaved.

                Looks at the headers (every fourth line) among the first 10000
                lines. Returns True as soon as two headers share the same ID
                (text before the first space), or at the end when both
                pattern1 and pattern2 matched at least one header. Else False.
                """
                opener=gzip.open if read.endswith(".gz") else open

                num_read1=0
                num_read2=0
                # A set keeps the duplicate lookup constant-time.
                headers=set()

                # The context manager also closes the file on the early return.
                with opener(read,"rt") as f:
                    for counter,line in enumerate(f):
                        if counter>=10000:
                            break
                        # Only look at read headers. They come every fourth line.
                        if counter%4!=0:
                            continue
                        if re.match(pattern1,line):
                            num_read1+=1
                        elif re.match(pattern2,line):
                            num_read2+=1

                        if line.startswith("@"):
                            ID=line.split(" ")[0]
                            if ID in headers:
                                return True
                            headers.add(ID)

                return num_read1>0 and num_read2>0


            def deinterleave(read,line):
                """
                Split an interleaved read file into a left and a right file.

                `line` is the first header of `read`; it is checked for new and
                old Illumina mate markers to decide how to split. Otherwise the
                reads are assumed to alternate: odd reads go to the left file,
                even reads to the right file.

                Existing output files are reused. Returns the two output paths.
                """
                print("Found interleaved reads in read file: "+read+".")

                # NOTE(review): unescaped dots match any character; left as is
                # so the derived names stay identical to previous runs.
                readname=re.split(r".fq|.fastq",os.path.basename(read))[0]
                deinterleaved=output_dir+"intermediate_data/deinterleaved."+readname+"."
                deinterleaved1=deinterleaved+"left.fastq.gz"
                deinterleaved2=deinterleaved+"right.fastq.gz"

                if os.path.exists(deinterleaved1) and os.path.exists(deinterleaved2):
                    print("Already found existing deinterleaved files for read file: "+read)
                    print("Continuing run with '"+deinterleaved1+"' and '"+deinterleaved2+"'.")
                    return deinterleaved1,deinterleaved2


                print("Deinterleaving data.")
                zip_command="cat"
                if read.endswith(".gz"):
                    zip_command="gzip -cd"

                # Quote the paths so spaces or shell metacharacters cannot break the pipelines.
                source=zip_command+" "+shlex.quote(read)
                left_out=shlex.quote(deinterleaved1)
                right_out=shlex.quote(deinterleaved2)

                # Check for Illumina(1.8+) or Illumina(1.8-) fastq format.
                # NOTE(review): both mate patterns are tested against the same
                # header line, so these two branches only run when that one line
                # contains both markers; otherwise the alternating split below is
                # used. Confirm whether 'or' was intended before changing it.
                if re.match("^@.+ 1:", line) and re.match("^@.+ 2:",line):
                    print("Extracting left reads of '"+read+"' to '"+deinterleaved1+"'.")
                    os.system(source+" | awk '/^@.*[ ]1:/{c=4} c&&c--' | gzip --stdout > "+left_out)
                    print("Extracting right reads of '"+read+"' to '"+deinterleaved2+"'.")
                    os.system(source+" | awk '/^@.*[ ]2:/{c=4} c&&c--' | gzip --stdout > "+right_out)
                elif re.match("^@.+/1",line) and re.match("^@.+/2",line):
                    print("Extracting left reads of '"+read+"' to '"+deinterleaved1+"'.")
                    os.system(source+" | awk '/^@.*[/]1/{c=4} c&&c--' | gzip --stdout > "+left_out)
                    print("Extracting right reads of '"+read+"' to '"+deinterleaved2+"'.")
                    os.system(source+" | awk '/^@.*[/]2/{c=4} c&&c--' | gzip --stdout > "+right_out)
                else:
                    # Assuming that interleaved reads in a file are alternating.
                    # `paste - - - - - - - -` joins each pair of 4-line records
                    # into one tab-separated row; fields 1-4 are the left read
                    # and fields 5-8 the right read. The input is read twice, once
                    # per mate. A single-pass variant with `tee >(...)` was noted
                    # by the original author as not working through os.system.
                    print("Extracting left reads of '"+read+"' to '"+deinterleaved1+"'.")
                    os.system(source+' | paste - - - - - - - - | cut -f 1-4 | tr "\t" "\n" | gzip > '+left_out)
                    print("Extracting right reads of '"+read+"' to '"+deinterleaved2+"'.")
                    os.system(source+' | paste - - - - - - - - | cut -f 5-8 | tr "\t" "\n" | gzip > '+right_out)


                return deinterleaved1,deinterleaved2

            # --- Outer function ---

            # Symbols that trigger header reformatting.
            bad_symbols=(" ","_",".")

            # Check that read files exists and can be read
            check_existence(read)

            # Subsample reads from read file and continue analysis on new, subsampled file.
            read=subsample(read,no_subsample)[1]
            # Get first line of read file
            fline=get_first_line(read)

            # Check if read file is interleaved.
            # If so we need to split reads (deinterleave) into two files.
            if is_interleaved(read,"^@.+/1","^@.+/2") or is_interleaved(read,"^@.+ 1:","^@.+ 2:"):
                # Deinterleaved reads
                read1,read2=deinterleave(read,fline)
                if any(symbol in get_first_line(read1) for symbol in bad_symbols) or any(symbol in get_first_line(read2) for symbol in bad_symbols):
                    # Reformat headers so that no whitespaces, underscores or dots are present.
                    # NEEDED since pysam and Trinity cannot handle that format.
                    # Conversion is: 'Wrong header format' -> 'readID/pair#'
                    print("Reformatting headers of: "+read1)
                    read1=change_header_format(read1,1)
                    print("Reformatting headers of: "+read2)
                    read2=change_header_format(read2,2)
                return [read1,read2]

            # Single read file is not interleaved.
            if any(symbol in fline for symbol in bad_symbols):
                # Same header conversion as above, for the single file.
                print("Reformating read headers in fastq file: "+read)
                read=change_header_format(read,1)

            return [read]


        def check_paired_reads(read1,read2):
            """
            Check that two paired read files exist and are in a usable format.

            Both files are subsampled, the mate order is worked out from the
            first header of each file (the files are swapped when the first
            one holds the second mates), and headers containing whitespace,
            underscores or dots are rewritten as 'readID/pair#'.

            Returns (read1, read2): the left and right read files to use.
            """

            check_existence(read1)
            check_existence(read2)

            check_readfile_extension(read1)
            check_readfile_extension(read2)
            # Subsample reads from read file and continue analysis on new, subsampled files.
            # Each file is subsampled with the full configured count.
            no_subsample=config["subsample"]
            header_flag1, read1=subsample(read1,no_subsample)
            header_flag2, read2=subsample(read2,no_subsample)
            fline1=get_first_line(read1)
            fline2=get_first_line(read2)

            # Check format.
            # Either Illumina1.8+: @<instrument>:<run number>:<flowcell ID>:<lane>:<tile>:<xpos>:<y-pos> <read>:<is filtered>:<control number>:<index>
            # Or Illumina1.8-: @<machine_id>:<lane>:<tile>:<x_coord>:<y_coord>#<index>/<read>
            header_format="Unknown"

            if re.match(r"^@.+/1",fline1) and re.match(r"^@.+/2",fline2):
                header_format="Old Illumina (Illumina1.8-)"
            elif re.match(r"^@.+/2",fline1) and re.match(r"^@.+/1",fline2):
                # Files were given right-then-left: swap them.
                read1,read2=read2,read1
                header_format="Old Illumina (Illumina1.8-)"
            elif re.match(r"^@.* 1:",fline1) and re.match(r"^@.* 2:",fline2):
                header_format="New Illumina (Illumina1.8+)"
            elif re.match(r"^@.* 2:",fline1) and re.match(r"^@.* 1:",fline2):
                header_format="New Illumina (Illumina1.8+)"
                read1,read2=read2,read1

            print("Based on read headers we are working with format: "+header_format)

            if header_format=="Unknown":
                print("Unknown read header format.")
                print("First line in :"+read1+"\n\t"+fline1)
                print("First line in :"+read2+"\n\t"+fline2)
                print("Assuming left read file: "+ read1+ " and right read file: "+read2)

            # Reformat only files created by this call (both flags True) whose
            # first header contains a whitespace, dot or underscore.
            # NEEDED since pysam and Trinity cannot handle that format.
            # Conversion is: 'Wrong header format' -> 'readID/pair#'   (pair# = 1 or 2 depending on the mate)
            has_bad_symbol=any(
                symbol in fline
                for fline in (fline1,fline2)
                for symbol in (" ",".","_")
            )
            if header_flag1 and header_flag2 and has_bad_symbol:
                print("Reformatting headers of: "+read1)
                read1=change_header_format(read1,1)
                print("Reformatting headers of: "+read2)
                read2=change_header_format(read2,2)
            else:
                print("No header reformatting needed, continuing")

            return read1,read2

        # Dispatch on how many read files were configured.
        configured_reads = config["reads"]
        read_count = len(configured_reads)
        if read_count == 1:
            # One file: verify single reads, or deinterleave paired-end reads.
            check_single_reads(configured_reads[0])
        elif read_count == 2:
            # Two files: verify the paired-end reads.
            check_paired_reads(configured_reads[0], configured_reads[1])
