#!/usr/bin/python
#
# Copyright (C) 2009-2012 Red Hat, Inc.
# Authors:
# Thomas Woerner <twoerner@redhat.com>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#

import sys
import getopt
import dbus

from firewall.client import FirewallClient
from firewall.errors import *

def usage():
    """Print the command-line synopsis to stdout.

    The grammar must mirror what the option parsing in this script accepts.
    In particular, --direct takes the command first and the IP family second,
    e.g. ``--direct --add-chain ipv4 filter mychain``.
    """
    print("Usage: %s { HELP | STATUS | PANIC | ZONE | MODE | DIRECT }" % sys.argv[0])
    print(
"""
 HELP   := { -h | --help | 
             -v | --version }
 STATUS := { --state |
             --reload |
             --complete-reload }
 PANIC  := { --enable | --disable | --query } --panic
 ZONE   := { --get-default-zone |
             --set-default-zone=<zone> |
             --get-zones |
             --get-active-zones |
             --get-zone-of-interface=<interface> }
 MODE   := { [ --zone=<zone> ] {
               { { --add [--timeout=<seconds>] | --remove | --query } ACTION } |
               { --change --interface=<interface> } |
               { { --enable [--timeout=<seconds>] } | --disable | --query }
                 --masquerade } |
               --list={ service | port | interface | icmp-block |
                 forward-port | all } } | 
             --get-services |
             --get-icmptypes }
 ACTION := { --service=<service> |
             --port=<port>[-<port>]/<protocol> |
             --interface=<interface> |
             --icmp-block=<icmptype> |
             --forward-port=port=<port>[-<port>]:proto=<protocol> { 
               :toport=<port>[-<port>] | :toaddr=<address> | 
               :toport=<port>[-<port>]:toaddr=<address> } }
 DIRECT := --direct {
             { --passthrough IPV <args> } |
             { { --add-chain | --remove-chain | --query-chain } IPV <table>
               <chain> } |
             { --get-chains IPV <table> } |
             { --add-rule IPV <table> <chain> <priority> <args> } |
             { { --remove-rule | --query-rule } IPV <table> <chain> <args> } |
             { --get-rules IPV <table> <chain> } }
 IPV    := { ipv4 | ipv6 | eb }
""")

def __fail(msg=None):
    """Terminate the command with exit status 2.

    msg -- optional text (or exception object) printed to stdout right
           before exiting; nothing is printed when it is empty or None.
    """
    exit_status = 2
    if msg:
        print(msg)
    raise SystemExit(exit_status)

# Regular options are parsed with getopt.  A --direct command line carries
# free-form firewall arguments that getopt would reject as unknown options,
# so in that case the argument vector is kept unparsed in "args" and decoded
# positionally later on.
if "--direct" not in sys.argv[1:]:
    try:
        (opts, args) = \
            getopt.getopt(sys.argv[1:], "hv",
                          [ "help", "version", "timeout=", "reload",
                            "complete-reload", "state",
                            "get-default-zone", "set-default-zone=",
                            "get-zones", "get-active-zones",
                            "get-zone-of-interface=",
                            # modes (exactly one of those)
                            "add", "change", "remove","enable", "disable",
                            "query", "list=", "get-services", "get-icmptypes",
                            # zone
                            "zone=",
                            # actions (exactly one of those)
                            "panic",
                            "interface=",
                            "service=", "port=", "masquerade",
                            "forward-port=", "icmp-block=",
                            ])
    except getopt.GetoptError as msg:
        # Unknown option or missing/unexpected option argument.
        print(msg)
        usage()
        sys.exit(1)

    if not opts:
        usage()
        sys.exit(1)
else:
    opts = [ ]
    args = sys.argv[1:]

# Defaults for everything the option handling may fill in.
timeout = 0
mode = None
action = None
value = None
zone = ""
interface = None

# A --direct command line is positional:
#   --direct --passthrough                            IPV <args>...
#   --direct --add-chain|--remove-chain|--query-chain IPV <table> <chain>
#   --direct --get-chains                             IPV <table>
#   --direct --add-rule                               IPV <table> <chain> <priority> <args>...
#   --direct --remove-rule|--query-rule               IPV <table> <chain> <args>...
#   --direct --get-rules                              IPV <table> <chain>
# args[1] selects the command and args[2] the IP family.  An unknown command
# or a wrong argument count leaves mode unset.
if len(args) > 2 and args[0] == "--direct":
    direct_ipv = args[2]
    if args[1] == "--passthrough" and len(args) > 3:
        mode = args[1][2:]
        direct_args = args[3:]
    elif args[1] in [ "--add-chain", "--remove-chain", "--query-chain" ] and \
            len(args) == 5:
        mode = args[1][2:]
        direct_table = args[3]
        direct_chain = args[4]
    elif args[1] == "--get-chains" and len(args) == 4:
        mode = args[1][2:]
        direct_table = args[3]
    elif args[1] == "--add-rule" and len(args) > 7:
        # Rule commands require at least two rule tokens after the fixed
        # positional arguments.
        mode = args[1][2:]
        direct_table = args[3]
        direct_chain = args[4]
        try:
            direct_priority = int(args[5])
        except ValueError:
            # Previously an uncaught ValueError produced a traceback.
            __fail("Priority '%s' is not an integer." % args[5])
        direct_args = args[6:]
    elif args[1] in [ "--remove-rule", "--query-rule" ] and \
            len(args) > 6:
        mode = args[1][2:]
        direct_table = args[3]
        direct_chain = args[4]
        direct_args = args[5:]
    elif args[1] == "--get-rules" and len(args) == 5:
        mode = args[1][2:]
        direct_table = args[3]
        direct_chain = args[4]

# Translate the getopt result into mode / action / value / zone / timeout.
# Exactly one mode and at most one action may be given; a conflicting second
# one aborts via __fail() (exit status 2).
for (opt, val) in opts:
    if opt in ["-h", "--help"]:
        usage()
        sys.exit(0)
    elif opt in ["-v", "--version"]:
        if mode:
            __fail()
        mode = "version"

    elif opt in [ "--reload", "--complete-reload", "--state" ]:
        if mode:
            __fail()
        mode = opt[2:]

    elif opt in [ "--get-default-zone", "--get-zones", "--get-active-zones" ]:
        if mode:
            __fail()
        mode = opt[2:]

    elif opt in [ "--set-default-zone", "--get-zone-of-interface" ]:
        if mode:
            __fail()
        mode = opt[2:]
        value = val

    # timeout
    elif opt == "--timeout":
        try:
            timeout = int(val)
        except ValueError:
            usage()
            sys.exit(2)
        if timeout < 1:
            __fail("Timeout not valid")

    # mode
    elif opt in [ "--enable", "--disable", "--query", "--add", "--change", "--remove" ]:
        if mode:
            __fail()
        mode = opt[2:]
    elif opt == "--list":
        if mode or action:
            __fail()
        mode = opt[2:]
        # The list target (service, port, ...) is stored as the action.
        action = val
    elif opt in [ "--get-services", "--get-icmptypes"]:
        # Like every other mode option, refuse to silently replace a mode
        # that was already given.
        if mode:
            __fail()
        mode = opt[2:]
    # zone
    elif opt == "--zone":
        if zone:
            __fail()
        zone = val

    # action
    elif opt in [ "--panic", "--interface",
                  "--service", "--port", "--masquerade",
                  "--forward-port", "--icmp-block" ]:
        if action:
            __fail()
        action = opt[2:]

        # --panic and --masquerade are flags; all other actions carry a value.
        if opt not in [ "--panic", "--masquerade" ]:
            if value:
                __fail()
            value = val

# Sanity-check the combination of mode, action, value and timeout.
if not mode:
    __fail("No mode.")
# These modes are complete on their own; every other mode operates on an
# action (--service, --panic, --list=<what>, ...).
if mode not in [ "version", "reload", "complete-reload", "state",
                 "get-default-zone", "set-default-zone", "get-zones",
                 "get-active-zones", "get-zone-of-interface",
                 "passthrough", "add-chain", "remove-chain", "query-chain",
                 "get-chains", "add-rule", "remove-rule", "query-rule",
                 "get-rules", "get-services", "get-icmptypes" ]:
    if not action:
        __fail("No action.")
    # --panic and --masquerade are plain flags and --list needs no value;
    # every other action carries one (e.g. --service=<service>).
    if action not in [ "panic", "masquerade"] and mode != "list" and not value:
        __fail("No value.")

# Reject mode/action pairs that the execution stage has no branch for, so
# they fail loudly instead of exiting 0 without doing anything (e.g.
# --change --service=..., --add --panic, --list=<unknown>).  Messages are
# plain strings like every other message in this script; "_" is neither
# defined nor explicitly imported in this file.
if mode == "list":
    if action not in [ "interface", "service", "port", "forward-port",
                       "icmp-block", "zone", "all" ]:
        __fail("Wrong action and mode combination")
elif action == "interface":
    if mode not in [ "add", "change", "remove", "query" ]:
        __fail("Wrong action and mode combination")
elif action in [ "service", "port", "forward-port", "icmp-block" ]:
    if mode not in [ "add", "remove", "query" ]:
        __fail("Wrong action and mode combination")
elif action in [ "panic", "masquerade" ] and \
        mode not in [ "enable", "disable", "query" ]:
    __fail("Wrong action and mode combination")

if timeout != 0:
    if mode != "add" and mode != "enable":
        __fail("Timeout only valid in enable or add mode.")
    if action == "panic":
        __fail("No timeout for panic.")

# Execute the validated request against the FirewallD D-Bus client.
# query-* requests report their answer through the exit status only:
# sys.exit(not <answer>) exits 0 for "yes" and 1 for "no".
try:
    fw = FirewallClient()

    if mode == "version":
        print(fw.get_property("version"))
        sys.exit(0)
    elif mode == "state":
        # Nothing is printed; a non-running daemon gives a non-zero status.
        if fw.get_property("state") != "RUNNING":
            sys.exit(-1)
    elif mode == "reload":
        if not fw.reload():
            sys.exit(1)
    elif mode == "complete-reload":
        fw.complete_reload()

    # direct interface
    elif mode == "passthrough":
        print(fw.passthrough(direct_ipv, direct_args))
    elif mode == "add-chain":
        fw.addChain(direct_ipv, direct_table, direct_chain)
    elif mode == "remove-chain":
        fw.removeChain(direct_ipv, direct_table, direct_chain)
    elif mode == "query-chain":
        sys.exit(not fw.queryChain(direct_ipv, direct_table, direct_chain))
    elif mode == "get-chains":
        print(" ".join(fw.getChains(direct_ipv, direct_table)))
    elif mode == "add-rule":
        fw.addRule(direct_ipv, direct_table, direct_chain, direct_priority,
                   direct_args)
    elif mode == "remove-rule":
        fw.removeRule(direct_ipv, direct_table, direct_chain, direct_args)
    elif mode == "query-rule":
        sys.exit(not fw.queryRule(direct_ipv, direct_table, direct_chain,
                                  direct_args))
    elif mode == "get-rules":
        for rule in fw.getRules(direct_ipv, direct_table, direct_chain):
            print(" ".join(rule))

    # zone queries
    elif mode == "get-default-zone":
        print(fw.getDefaultZone())
    elif mode == "set-default-zone":
        fw.setDefaultZone(value)
    elif mode == "get-zones":
        print(" ".join(fw.getZones()))
    elif mode == "get-active-zones":
        active_zones = fw.getActiveZones()
        for zone_name in active_zones:
            print("%s: %s" % (zone_name, " ".join(active_zones[zone_name])))
    elif mode == "get-zone-of-interface":
        try:
            print(fw.getZoneOfInterface(value))
        except dbus.DBusException as e:
            # Best effort: an error reported by the daemon (presumably "the
            # interface is in no zone" -- TODO confirm) prints nothing.  A
            # missing daemon or failed authorization is re-raised so the
            # handler at the end of this block reports it with its exit code.
            error_name = e.get_dbus_name() or ""
            if error_name == "org.freedesktop.DBus.Error.ServiceUnknown" or \
                    "NotAuthorizedException" in error_name:
                raise
    elif mode == "get-services":
        services = fw.listServices()
        if services:
            print(" ".join(services))
    elif mode == "get-icmptypes":
        icmptypes = fw.listIcmpTypes()
        if icmptypes:
            print(" ".join(icmptypes))
    else:
        # Everything else operates on an action within "zone".  An empty zone
        # string is passed to the client unchanged; the list=all output treats
        # it as the default zone.
        # panic
        if action == "panic":
            if mode == "enable":
                fw.enablePanicMode()
            elif mode == "disable":
                fw.disablePanicMode()
            elif mode == "query":
                sys.exit(not fw.queryPanicMode())

        # zone
        elif action == "zone":
            if mode == "list":
                zones = fw.getZones()
                if zones:
                    print(" ".join(zones))

        # interface
        elif action == "interface":
            if mode == "list":
                interfaces = fw.getInterfaces(zone)
                if interfaces:
                    print(" ".join(interfaces))
            elif mode == "add":
                fw.addInterface(zone, value)
            elif mode == "change":
                fw.changeZone(zone, value)
            elif mode == "remove":
                fw.removeInterface(zone, value)
            elif mode == "query":
                sys.exit(not fw.queryInterface(zone, value))

        # service
        elif action == "service":
            if mode == "list":
                services = fw.getServices(zone)
                if services:
                    print(" ".join(services))
            elif mode == "add":
                fw.addService(zone, value, timeout)
            elif mode == "remove":
                fw.removeService(zone, value)
            elif mode == "query":
                sys.exit(not fw.queryService(zone, value))

        # port
        elif action == "port":
            if mode == "list":
                ports = fw.getPorts(zone)
                if ports:
                    print(" ".join(["%s/%s" % (entry[0], entry[1])
                                    for entry in ports]))
            else:
                # value is "<port>[-<port>]/<protocol>"
                try:
                    (port, proto) = value.split("/")
                except ValueError as msg:
                    __fail(msg)

                if mode == "add":
                    fw.addPort(zone, port, proto, timeout)
                elif mode == "remove":
                    fw.removePort(zone, port, proto)
                elif mode == "query":
                    sys.exit(not fw.queryPort(zone, port, proto))

        # masquerade
        elif action == "masquerade":
            if mode == "enable":
                fw.enableMasquerade(zone, timeout)
            elif mode == "disable":
                fw.disableMasquerade(zone)
            elif mode == "query":
                sys.exit(not fw.queryMasquerade(zone))

        # forward port
        elif action == "forward-port":
            if mode == "list":
                forwards = fw.getForwardPorts(zone)
                if forwards:
                    print("\n".join(["port=%s:proto=%s:toport=%s:toaddr=%s"
                                     % (port, protocol, toport, toaddr)
                                     for (port, protocol, toport, toaddr)
                                     in forwards]))
            else:
                # value is a ':'-separated list of key=value settings.
                port = None
                protocol = None
                toport = None
                toaddr = None
                for setting in value.split(":"):
                    try:
                        (key, setting_value) = setting.split("=")
                    except ValueError:
                        __fail("invalid forward port arg '%s'" % (setting))
                    if key == "port":
                        port = setting_value
                    elif key == "proto":
                        protocol = setting_value
                    elif key == "toport":
                        toport = setting_value
                    elif key == "toaddr":
                        toaddr = setting_value
                if not port:
                    __fail("missing port")
                if not protocol:
                    __fail("missing protocol")
                if not (toport or toaddr):
                    __fail("missing destination")

                if mode == "add":
                    fw.addForwardPort(zone, port, protocol, toport, toaddr,
                                      timeout)
                elif mode == "remove":
                    fw.removeForwardPort(zone, port, protocol, toport, toaddr)
                elif mode == "query":
                    sys.exit(not fw.queryForwardPort(zone, port, protocol,
                                                     toport, toaddr))

        # block icmp
        elif action == "icmp-block":
            if mode == "list":
                blocks = fw.getIcmpBlocks(zone)
                if blocks:
                    print(" ".join(blocks))
            elif mode == "add":
                fw.addIcmpBlock(zone, value, timeout)
            elif mode == "remove":
                fw.removeIcmpBlock(zone, value)
            elif mode == "query":
                sys.exit(not fw.queryIcmpBlock(zone, value))

        elif action == "all":
            if mode == "list":
                print("zone: " + (zone if zone != "" else fw.getDefaultZone()))
                interfaces = fw.getInterfaces(zone)
                if interfaces:
                    print("interfaces: " + " ".join(interfaces))
                services = fw.getServices(zone)
                if services:
                    print("services: " + " ".join(services))
                ports = fw.getPorts(zone)
                if ports:
                    print("ports: " + " ".join(["%s/%s" % (entry[0], entry[1])
                                                for entry in ports]))
                forwards = fw.getForwardPorts(zone)
                if forwards:
                    print("forward-ports: " +
                          "\n".join(["port=%s:proto=%s:toport=%s:toaddr=%s"
                                     % (port, protocol, toport, toaddr)
                                     for (port, protocol, toport, toaddr)
                                     in forwards]))
                blocks = fw.getIcmpBlocks(zone)
                if blocks:
                    print("icmp-blocks: " + " ".join(blocks))

except dbus.DBusException as e:
    # The D-Bus error name is optional on DBusException and may be None, so
    # normalise it before the substring test (which would otherwise raise
    # TypeError inside this handler).  Public accessors are used instead of
    # the private _dbus_error_name and the deprecated .message attribute.
    error_name = e.get_dbus_name() or ""
    if error_name == 'org.freedesktop.DBus.Error.ServiceUnknown':
        print("FirewallD is probably not running.")
        sys.exit(NOT_RUNNING)
    elif "NotAuthorizedException" in error_name:
        print("Authorization failed.")
        sys.exit(NOT_AUTHORIZED)
    else:
        error_message = e.get_dbus_message()
        try:
            code = FirewallError.get_code(error_message)
        except Exception:
            code = UNKNOWN_ERROR
            print("Error: %s" % e)
        else:
            print("Error: %s" % error_message)
        sys.exit(code)

sys.exit(0)
