#!/usr/libexec/platform-python

# Copyright © 2024, Oracle and/or its affiliates.  All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import argparse
import fcntl
import glob
import json
import os
import re
import subprocess
import sys
import time

# Nothing to do when the prefix file is absent: exit successfully.
if not os.path.exists('/etc/oci-fss-utils.d/prefix.txt'):
    sys.exit(0)

with open('/etc/oci-fss-utils.d/prefix.txt', encoding="utf-8") as fp:
    PREFIX = fp.readline().strip()

# The first line of the prefix file must be either "<prefix>::/64" (IPv6) or
# "a.b.c.0/24" (IPv4).  After validation PREFIX holds the value with its
# network suffix removed, and MOUNT_PATTERN matches a mount source that starts
# with an address under that prefix, capturing the trailing component.
#
# Dots are escaped and PREFIX is passed through re.escape() so that "." only
# ever matches a literal dot; raw strings avoid invalid "\[" string escapes.
if re.match(r"^[a-zA-Z0-9:]+::/64$", PREFIX):
    PROTOCOL = "ipv6"
    PREFIX = PREFIX.replace("::/64", "")
    MOUNT_PATTERN = r"^\[%s::([0-9a-z]{1,4})\]" % re.escape(PREFIX)
elif re.match(r"^[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.0/24$", PREFIX):
    PROTOCOL = "ipv4"
    PREFIX = PREFIX.replace(".0/24", "")
    MOUNT_PATTERN = r"^%s\.([0-9]{1,3})" % re.escape(PREFIX)
else:
    raise ValueError("Invalid prefix in /etc/oci-fss-utils.d/prefix.txt.")

# Legacy mounts use a fixed 192.168.<n>.2 source; group 1 captures <n>.
MOUNT_PATTERN_LEGACY = r"^192\.168\.([0-9]+)\.2:/"


class MountLock:
    """Exclusive advisory lock on /run/oci-fss-mount.lck.

    The lock file is opened on construction; call acquire() to take the
    exclusive flock.  The lock is dropped and the file closed by release(),
    on leaving a ``with`` block, or when the object is destroyed.
    """

    def __init__(self):
        self.fd = open("/run/oci-fss-mount.lck", "w+")

    def __del__(self):
        self.release()

    def __enter__(self):
        self.acquire()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.release()
        return False

    def acquire(self):
        """Block until the exclusive lock on the lock file is obtained."""
        fcntl.flock(self.fd, fcntl.LOCK_EX)

    def release(self):
        """Unlock and close the lock file; safe to call more than once."""
        # `fd` is missing when open() raised inside __init__, in which case
        # __del__ still runs and must not fail.
        fd = getattr(self, "fd", None)
        if fd is None or fd.closed:
            return
        try:
            fcntl.flock(fd, fcntl.LOCK_UN)
        finally:
            fd.close()


def run_cmd(cmd):
    """Print *cmd*, then execute it without a shell.

    The command string is split on single space characters to build the
    argument vector.  Returns the subprocess.CompletedProcess for the run.
    """
    print(cmd)
    argv = cmd.split(' ')
    completed = subprocess.run(argv)
    return completed


def find_mounts_in_mounts_file(fp, re_mount, re_legacy, mounts, mounts_legacy, slots):
    """Collect the slot numbers of matching mounts from a mounts listing.

    Each line of *fp* is split on spaces and its first field is tested
    against *re_mount* and then *re_legacy*.  Results are accumulated in the
    caller-supplied containers:

    * mounts        -- set of slot numbers matched by *re_mount*
    * mounts_legacy -- set of slot numbers matched by *re_legacy*
    * slots         -- dict mapping slot number to the matched text
    """
    for entry in fp:
        source = entry.strip().split(" ")[0]

        hit = re_mount.match(source)
        if hit and PROTOCOL in ("ipv6", "ipv4"):
            # Group 1 is parsed as hexadecimal for ipv6, decimal for ipv4.
            base = 16 if PROTOCOL == "ipv6" else 10
            slot_no = int(hit.group(1), base)
            mounts.add(slot_no)
            slots[slot_no] = hit[0]
            continue

        legacy_hit = re_legacy.match(source)
        if legacy_hit:
            slot_no = int(legacy_hit.group(1))
            mounts_legacy.add(slot_no)
            slots[slot_no] = legacy_hit[0]


def dump_mounts():
    """Scan every visible mounts table for current and legacy mounts.

    Returns a tuple ``(mounts, mounts_legacy, slots)`` as filled in by
    find_mounts_in_mounts_file(): two sets of slot numbers and a dict mapping
    each slot number to the text that matched.
    """
    pid_re = re.compile("^[0-9]+$")
    mount_re = re.compile(MOUNT_PATTERN)
    legacy_re = re.compile(MOUNT_PATTERN_LEGACY)

    mounts, mounts_legacy, slots = set(), set(), {}

    # Every process may live in its own mount namespace, so inspect each
    # numeric /proc entry individually.
    for entry in os.listdir("/proc"):
        if not pid_re.match(entry):
            continue
        try:
            with open(f"/proc/{entry}/mounts", encoding="utf-8") as fp:
                find_mounts_in_mounts_file(
                    fp, mount_re, legacy_re, mounts, mounts_legacy, slots)
        except OSError:
            # Best effort: skip any process whose mounts file cannot be
            # opened or read (FileNotFoundError is an OSError too).
            pass

    # Finally the root mount namespace; failures here are not suppressed.
    with open("/proc/mounts", encoding="utf-8") as fp:
        find_mounts_in_mounts_file(
            fp, mount_re, legacy_re, mounts, mounts_legacy, slots)

    return mounts, mounts_legacy, slots


def _slot_addr(slot):
    """Return the address string for *slot* under PREFIX, or None if PROTOCOL is unknown."""
    if PROTOCOL == "ipv6":
        return f"[{PREFIX}::{slot:x}]"
    if PROTOCOL == "ipv4":
        return f"{PREFIX}.{slot}"
    return None


def _rebuild_mt_files():
    """Rewrite the mt-*.txt files so they reflect the remaining addr-*.json files."""
    addr_prefix = "/run/oci-fss-utils.d/addr-"
    addr_suffix = ".json"
    mt_files = {}

    # Scan all the addr files and work out what each mt file should contain...
    for addr_file in glob.glob("/run/oci-fss-utils.d/addr-*.json"):
        with open(addr_file, "r", encoding="utf-8") as fp:
            data = json.load(fp)
        mt_file_name = f"/run/oci-fss-utils.d/mt-{data['mount-target']}.txt"
        # The address is the file name without its "addr-" prefix and ".json" suffix.
        addr = addr_file[len(addr_prefix):-len(addr_suffix)]
        # After a live update of duplicate forwarders, this may be
        # overwritten multiple times, which is ok...
        mt_files[mt_file_name] = addr

    for mt_file_name, addr in mt_files.items():
        write_file = True
        if os.path.exists(mt_file_name):
            with open(mt_file_name, "r", encoding="utf-8") as fp:
                # Only the first line is compared.  readline() returns "" for
                # an empty file, so an empty mt file is simply rewritten
                # instead of raising IndexError.
                if addr == fp.readline().strip():
                    write_file = False

        if write_file:
            with open(mt_file_name, "w", encoding="utf-8") as fp:
                print(f"updating {mt_file_name} with {addr}.")
                fp.write(f"{addr}\n")
        else:
            print(f"{mt_file_name} already contains {addr}.")

    # Remove mt files that no addr file refers to any more.
    for mt_file in glob.glob("/run/oci-fss-utils.d/mt-*.txt"):
        if mt_file not in mt_files:
            os.unlink(mt_file)
            print(f"deleted {mt_file}")


def do_gc():
    """Delete slot and addr files for slots that no longer have a mount.

    Holds the MountLock while working.  Slots 1..254 are checked against the
    mounts reported by dump_mounts(); a slot with no mount but with a
    slot-<n>.txt file has that file (and its addr-<address>.json, if present)
    removed.  When at least one slot was collected, the mt-*.txt files are
    rebuilt from the remaining addr files.  Progress is printed to stdout.
    """
    # The lock is released when `lck` is destroyed (see MountLock.__del__).
    lck = MountLock()
    lck.acquire()

    mounts, mounts_legacy, slots = dump_mounts()
    gc_count = 0
    live_count = 0

    # Mount targets follow this scheme: oci-fss-{i:04}
    for i in range(1, 255):
        if i in mounts:
            print(f"slot {i} has an associated mount...")
            live_count += 1
            continue
        if i in mounts_legacy:
            print(f"slot {i} has an associated legacy mount...")
            live_count += 1
            continue

        slot_file = f"/run/oci-fss-utils.d/slot-{i}.txt"
        if not os.path.isfile(slot_file):
            continue

        if i not in slots:
            addr = _slot_addr(i)
            if addr is not None:
                addr_file = f"/run/oci-fss-utils.d/addr-{addr}.json"
                if os.path.isfile(addr_file):
                    os.unlink(addr_file)
                    print(f"deleted {addr_file}")
                else:
                    print(f"no {addr_file} to delete???")

        os.unlink(slot_file)
        print(f"deleted {slot_file}")

        print(f"removed mount on slot {i:04}...")
        gc_count += 1

    if gc_count > 0:
        _rebuild_mt_files()

    print(f"live count: {live_count}")


if __name__ == "__main__":
    # Command line: a single optional -d/--daemon switch.
    parser = argparse.ArgumentParser(
        prog='oci-fss-gc',
        usage="oci-fss-gc [options]",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument('-d', '--daemon', action='store_true', required=False)
    args = parser.parse_args()

    if args.daemon:
        # Daemon mode never returns: wait five minutes, collect, repeat.
        while True:
            time.sleep(300)
            do_gc()
    else:
        # One-shot mode: a single collection pass.
        do_gc()
