// Package sip is a very basic (and incomplete) implementation of the SIP messaging protocol.
//
// Based on the sipgo library, we are only adding some helpers useful in the context of exploitation.
// For example, we prefer to avoid the complexity of concepts like transactions or dialogs.
//
// References:
// - https://github.com/emiago/sipgo
// - https://datatracker.ietf.org/doc/html/rfc3261
// - https://datatracker.ietf.org/doc/html/rfc3311
// - https://datatracker.ietf.org/doc/html/rfc3581
// - https://datatracker.ietf.org/doc/html/rfc3265
// - https://datatracker.ietf.org/doc/html/rfc7118
package sip

import (
	"bufio"
	_ "embed"
	"fmt"
	"io"
	"net"
	"strconv"
	"strings"
	"time"

	"github.com/emiago/sipgo/sip"
	"github.com/google/uuid"
	"github.com/vulncheck-oss/go-exploit/output"
	"github.com/vulncheck-oss/go-exploit/protocol"
	"github.com/vulncheck-oss/go-exploit/random"
)

// Protocol limits and the default header/body values used when building requests.
const (
	// SIP over UDP message size should be lower than 1300 bytes: when the path
	// MTU is unknown, a request larger than this must be sent over a
	// congestion-controlled transport such as TCP. This is a limit for senders,
	// not for the size of messages that may be received.
	// https://datatracker.ietf.org/doc/html/rfc3261#section-18.1.1
	UDPMessageLength          = 1300
	errMsgRequiredParam       = "Required parameter: %s"
	DefaultMaxForwards        = 70
	DefaultCSeq               = 1
	DefaultInviteContentType  = "application/sdp"
	DefaultMessageContentType = "text/plain"
	DefaultMessageBodyContent = "Hello, this is a test message."
	DefaultInfoContentType    = "application/dtmf-relay"
	DefaultInfoBodyContent    = "Signal=1\nDuration=100"
	DefaultExpiresHeader      = 3600
	DefaultRackCSeq           = 1
	DefaultRackInviteCSeq     = 314159
	ContentTypePidf           = "application/pidf+xml"
	// Even though the specification requires "\r\n", it is common to see
	// implementations using "\n" as a line ending.
	// https://datatracker.ietf.org/doc/html/rfc3261#section-7.5
	sipLineEnd    = "\r\n"
	sipLineEndAlt = "\n"
)

// GlobalUA is the default User-Agent for all SIP go-exploit comms.
//
// Its value is embedded at build time from user-agent.txt (path relative to
// this package's directory). NewSipRequest trims surrounding whitespace from
// it before use, so a trailing newline in the file does not leak into the
// User-Agent header.
//
//go:embed user-agent.txt
var GlobalUA string

// TransportType identifies a transport protocol supported by this package.
type TransportType int

const (
	// UNKNOWN is the zero value; it is also what ParseTransport returns for
	// names it does not recognize.
	UNKNOWN TransportType = iota
	UDP
	TCP
	TLS
	WS
	WSS
)

// transportNames holds the String() form of each declared TransportType,
// indexed by its numeric value.
var transportNames = [...]string{"UNKNOWN", "UDP", "TCP", "TLS", "WS", "WSS"}

// String returns the upper-case name of t (e.g. "UDP"), which is also the form
// sipgo expects for transports. Values outside the declared range are rendered
// as "TransportType(N)" instead of panicking with an index out of range.
func (t TransportType) String() string {
	if t < 0 || int(t) >= len(transportNames) {
		return "TransportType(" + strconv.Itoa(int(t)) + ")"
	}

	return transportNames[t]
}

// DefaultPorts maps every concrete transport to the well-known server port
// that sipgo defines for it. UNKNOWN deliberately has no entry, so looking it
// up yields 0.
var DefaultPorts = map[TransportType]int{
	WSS: sip.DefaultWssPort,
	WS:  sip.DefaultWsPort,
	TLS: sip.DefaultTlsPort,
	TCP: sip.DefaultTcpPort,
	UDP: sip.DefaultUdpPort,
}

// SendAndReceiveTCP writes req to a TCP/TLS connection and returns the first
// SIP response read back from it.
//
// Requests sent by the server in the meantime (for example OPTIONS or NOTIFY)
// are logged at debug level and skipped. One buffered reader is shared by all
// reads of the exchange, so bytes of a following message that arrive in the
// same segment are not thrown away between iterations.
//
// Returns (nil, false) if conn or req is nil, or if writing, reading or
// parsing fails.
func SendAndReceiveTCP(conn net.Conn, req sip.Message) (*sip.Response, bool) {
	if conn == nil {
		output.PrintfFrameworkError(errMsgRequiredParam, "conn")

		return nil, false
	}
	if req == nil {
		output.PrintfFrameworkError(errMsgRequiredParam, "req")

		return nil, false
	}
	ok := protocol.TCPWrite(conn, []byte(req.String()))
	if !ok {
		output.PrintfFrameworkError("Writing message %s to the socket", req.String())

		return nil, false
	}
	reader := bufio.NewReader(conn)
	// To discard requests coming from the server, like OPTIONS or NOTIFY.
	for {
		respMsg, ok := readMessageTCP(reader)
		if !ok {
			output.PrintfFrameworkError("Reading response from the socket")

			return nil, false
		}
		resp, ok := respMsg.(*sip.Response)
		if !ok {
			output.PrintfFrameworkDebug("Response is not a valid: %+v", respMsg)

			continue
		}

		return resp, true
	}
}

// ReadMessageTCP reads and parses a single SIP message (headers and body) from
// a TCP/TLS connection.
//
// A fresh buffered reader is created on each call. Any bytes it pulls from
// conn beyond the end of this message are lost. Callers reading several
// messages from one stream should reuse a single reader, as SendAndReceiveTCP
// does.
func ReadMessageTCP(conn net.Conn) (sip.Message, bool) {
	if conn == nil {
		output.PrintfFrameworkError(errMsgRequiredParam, "conn")

		return nil, false
	}

	return readMessageTCP(bufio.NewReader(conn))
}

// readMessageTCP reads the header section up to the first empty line, then
// exactly Content-Length body bytes, and returns the fully parsed message.
func readMessageTCP(reader *bufio.Reader) (sip.Message, bool) {
	var head strings.Builder
	for {
		line, err := reader.ReadString('\n')
		if err != nil {
			output.PrintfFrameworkError("Reading response from the socket: %s", err.Error())

			return nil, false
		}
		if line == sipLineEnd || line == sipLineEndAlt {
			// RFC 3261 §7.5: CRLFs before the start-line (e.g. CRLF
			// keep-alives) must be ignored on stream transports.
			if head.Len() == 0 {
				continue
			}
			// An empty line after the headers ends the header section.
			head.WriteString(line)

			break
		}
		head.WriteString(line)
	}
	raw := head.String()
	// First message parse to get the length of the body.
	msg, err := sip.ParseMessage([]byte(raw))
	if err != nil {
		output.PrintfFrameworkError("Parsing response %+v: %s", raw, err.Error())

		return nil, false
	}
	// A missing Content-Length header is treated as an empty body instead of
	// dereferencing a nil header.
	var bodyLen int64
	if cl := msg.ContentLength(); cl != nil {
		bodyLen, err = strconv.ParseInt(cl.Value(), 10, 64)
		if err != nil {
			output.PrintfFrameworkError("Parsing Content-Length header: %s", err.Error())

			return nil, false
		}
	}
	if bodyLen <= 0 {
		return msg, true
	}
	body := make([]byte, bodyLen)
	count, err := io.ReadFull(reader, body)
	if err != nil {
		output.PrintfFrameworkError("Reading message body: %s", err.Error())

		return nil, false
	}
	output.PrintfFrameworkDebug("Read %d bytes of body", count)
	full := raw + string(body)
	// Second parse, now with the body included; this is the message returned.
	msg, err = sip.ParseMessage([]byte(full))
	if err != nil {
		output.PrintfFrameworkError("Parsing response %+v with body: %s", full, err.Error())

		return nil, false
	}

	return msg, true
}

// SendAndReceiveUDP writes req to a UDP connection and returns the first SIP
// response that comes back.
//
// Each incoming datagram is read and parsed with ReadMessageUDP. Requests sent
// by the server in the meantime (for example OPTIONS or NOTIFY) are logged at
// debug level and skipped. The loop only ends on a response or on a read/parse
// failure, so callers should set a read deadline on conn to avoid blocking
// forever.
//
// Returns (nil, false) if conn or req is nil, or if writing, reading or
// parsing fails.
func SendAndReceiveUDP(
	conn *net.UDPConn, req sip.Message,
) (*sip.Response, bool) {
	if conn == nil {
		output.PrintfFrameworkError(errMsgRequiredParam, "conn")

		return nil, false
	}
	if req == nil {
		output.PrintfFrameworkError(errMsgRequiredParam, "req")

		return nil, false
	}
	ok := protocol.UDPWrite(conn, []byte(req.String()))
	if !ok {
		output.PrintfFrameworkError("Writing message %s to the socket", req.String())

		return nil, false
	}
	// To discard requests coming from the server, like OPTIONS or NOTIFY.
	for {
		respMsg, ok := ReadMessageUDP(conn)
		if !ok {
			output.PrintFrameworkError("Reading response from the socket")

			return nil, false
		}
		resp, ok := respMsg.(*sip.Response)
		if !ok {
			output.PrintfFrameworkDebug("Response is not a valid: %+v", respMsg)

			continue
		}

		return resp, true
	}
}

// ReadMessageUDP reads a single datagram from conn and parses it as a SIP
// message.
//
// The receive buffer is sized for the largest possible UDP datagram, not for
// UDPMessageLength. RFC 3261 §18.1.1 uses the 1300-byte figure only to decide
// when a sender must switch to a congestion-controlled transport. It requires
// implementations to handle messages up to the maximum datagram size
// (65,535 bytes). A smaller buffer silently truncates larger messages (e.g.
// responses with big SDP bodies), which then fail to parse.
// https://datatracker.ietf.org/doc/html/rfc3261#section-18.1.1
func ReadMessageUDP(conn *net.UDPConn) (sip.Message, bool) {
	// Largest value representable in the 16-bit UDP length field.
	const maxDatagramSize = 65535
	if conn == nil {
		output.PrintfFrameworkError(errMsgRequiredParam, "conn")

		return nil, false
	}
	buf := make([]byte, maxDatagramSize)
	count, err := conn.Read(buf)
	if err != nil {
		output.PrintFrameworkError("Failed to read from the socket: " + err.Error())

		return nil, false
	}
	data := buf[:count]
	msg, err := sip.ParseMessage(data)
	if err != nil {
		// %q renders the raw bytes as a readable, escaped string.
		output.PrintfFrameworkError("Parsing response %q: %s", data, err.Error())

		return nil, false
	}

	return msg, true
}

// NewSipRequest returns a generic, well-formed request for method addressed to
// host. It is useful to start the communication with a target.
//
// The request always carries the following headers:
//   - Via
//   - Max-Forwards (DefaultMaxForwards)
//   - From, with a random tag
//   - To
//   - Call-ID, a random UUID
//   - CSeq (DefaultCSeq)
//   - User-Agent (GlobalUA, trimmed)
//
// Depending on the method, these are added automatically:
//   - a Contact header, see IsContactRequired
//   - method-specific headers, see AddMethodHeaders
//   - a default body with its Content-Type, see NewRequestBody
//
// Bodiless requests get an explicit "Content-Length: 0". opts may be nil.
//
// Returns (nil, false) if method or host is empty, or if the request cannot
// be built.
func NewSipRequest(
	method sip.RequestMethod, host string, opts *NewSipRequestOpts,
) (*sip.Request, bool) {
	if method == "" {
		output.PrintfFrameworkError(errMsgRequiredParam, "method")

		return nil, false
	}
	if host == "" {
		output.PrintfFrameworkError(errMsgRequiredParam, "host")

		return nil, false
	}
	if opts == nil {
		opts = &NewSipRequestOpts{}
	}
	// The local address (Via/From/Contact) defaults to the remote one.
	lHost := host
	if opts.LocalHost != "" {
		lHost = opts.LocalHost
	}
	lPort := opts.Port
	if opts.LocalPort != 0 {
		lPort = opts.LocalPort
	}
	toURI := sip.Uri{
		User: opts.ToUser,
		Host: host,
		Port: opts.Port,
	}
	fromURI := sip.Uri{
		User:     opts.User,
		Password: opts.Password,
		Host:     lHost,
		Port:     lPort,
	}
	req := sip.NewRequest(method, toURI)
	if opts.Transport != UNKNOWN {
		req.SetTransport(opts.Transport.String())
	}
	callID, err := uuid.NewRandom()
	if err != nil {
		output.PrintfFrameworkError("Could not generate UUID: %s", err.Error())

		return nil, false
	}
	callIDHeader := sip.CallIDHeader(callID.String())
	viaOpts := NewViaOpts{
		Host:      lHost,
		Port:      lPort,
		Transport: req.Transport(),
	}
	viaHeader, ok := NewViaHeader(&viaOpts)
	if !ok {
		output.PrintfFrameworkError("Creating Via header with options %+v", viaOpts)

		return nil, false
	}
	fromHeader := &sip.FromHeader{
		Address: fromURI,
		Params:  map[string]string{"tag": sip.GenerateTagN(10)},
	}
	toHeader := &sip.ToHeader{Address: toURI}
	maxForwardsHeader := sip.MaxForwardsHeader(DefaultMaxForwards)
	req.PrependHeader(
		viaHeader,
		&maxForwardsHeader,
		fromHeader,
		toHeader,
		&callIDHeader,
		&sip.CSeqHeader{SeqNo: uint32(DefaultCSeq), MethodName: method},
		sip.NewHeader("User-Agent", strings.TrimSpace(GlobalUA)),
	)
	if IsContactRequired(method) {
		req.AppendHeader(&sip.ContactHeader{
			Address: fromURI,
			Params:  sip.HeaderParams{"transport": req.Transport()},
		})
	}
	// The Call-ID header is already present, so REFER's precondition holds;
	// the result is still checked rather than silently dropped.
	if !AddMethodHeaders(req, method) {
		output.PrintfFrameworkError("Adding headers for method %s", method)

		return nil, false
	}
	body, contentType := NewRequestBody(method)
	if body != "" {
		contentTypeHeader := sip.ContentTypeHeader(contentType)
		req.AppendHeader(&contentTypeHeader)
		req.SetBody([]byte(body))
		// The Content-Length header is automatically added by sipgo.
	} else {
		contentLengthHeader := sip.ContentLengthHeader(0)
		req.AppendHeader(&contentLengthHeader)
	}

	return req, true
}

// NewSipRequestOpts holds the optional parameters of NewSipRequest. The zero
// value is valid and selects every default listed below.
type NewSipRequestOpts struct {
	// Remote port, used in the Request-URI and the To header.
	// Default: 0 (no explicit port; the host could be a domain name).
	Port int
	// Local host, used in the Via header and the From (and, when required,
	// Contact) URI. Default: 'host' function parameter.
	LocalHost string
	// Local port, used in the Via header and the From/Contact URI.
	// Default: 'Port' property.
	LocalPort int
	// User part of the Request-URI and the To header.
	// Default: No user set in the request.
	ToUser string
	// User part of the From (and Contact) URI.
	// Default: No user set in the request (From header).
	User string
	// Password placed in the userinfo of the From (and Contact) URI.
	// NOTE(review): RFC 3261 §19.1.1 discourages passwords in URIs; only set
	// this deliberately. Default: No password set in the request (From header).
	Password string
	// Transport protocol. Default: UNKNOWN (zero value), in which case
	// NewSipRequest does not set one and sipgo's default applies (UDP).
	Transport TransportType
}

// NewViaHeader builds a 'Via' header for an outgoing SIP request.
//
// A nil opts, or any empty string field in it, falls back to these defaults:
// protocol "SIP", version "2.0", transport UDP and host "localhost". The port
// is copied unchanged, because zero is legitimate (e.g. when the host is a
// domain name). Every header gets a fresh branch from sip.GenerateBranchN(16)
// and an empty "rport" parameter (RFC 3581). The boolean result is currently
// always true.
func NewViaHeader(opts *NewViaOpts) (*sip.ViaHeader, bool) {
	var cfg NewViaOpts
	if opts != nil {
		cfg = *opts
	}
	orDefault := func(value, fallback string) string {
		if value == "" {
			return fallback
		}

		return value
	}

	header := &sip.ViaHeader{
		ProtocolName:    orDefault(cfg.ProtocolName, "SIP"),
		ProtocolVersion: orDefault(cfg.ProtocolVersion, "2.0"),
		Transport:       orDefault(cfg.Transport, sip.TransportUDP),
		Host:            orDefault(cfg.Host, "localhost"),
		Port:            cfg.Port,
		Params: sip.HeaderParams{
			"branch": sip.GenerateBranchN(16),
			"rport":  "",
		},
	}

	return header, true
}

// NewViaOpts holds the optional parameters of NewViaHeader. Empty or zero
// fields fall back to the defaults listed below; the zero value is valid.
type NewViaOpts struct {
	// Protocol name of the sent-protocol part. Default: "SIP"
	ProtocolName string
	// Protocol version of the sent-protocol part. Default: "2.0"
	ProtocolVersion string
	// Transport of the sent-protocol part. Default: "UDP"
	Transport string
	// Host of the sent-by part. Default: "localhost"
	Host string
	// Port of the sent-by part. Default: 0 (kept as-is, e.g. when Host is a
	// domain name).
	Port int
}

// IsContactRequired reports whether requests of the given method need a
// Contact header. REGISTER, INVITE, SUBSCRIBE, REFER, PUBLISH and NOTIFY need
// one; every other method, known or unknown, does not.
func IsContactRequired(method sip.RequestMethod) bool {
	required := false
	switch method {
	case sip.INVITE, sip.NOTIFY, sip.PUBLISH, sip.REFER, sip.REGISTER, sip.SUBSCRIBE:
		required = true
	case sip.ACK, sip.BYE, sip.CANCEL, sip.INFO, sip.MESSAGE, sip.OPTIONS, sip.PRACK, sip.UPDATE:
		// Listed explicitly so the switch stays exhaustive over known methods.
	default:
		// Unknown methods get no Contact header.
	}

	return required
}

// AddMethodHeaders appends to req the generic headers specific to method:
//   - OPTIONS: Accept: application/sdp.
//   - REGISTER: Expires (DefaultExpiresHeader).
//   - PRACK: RAck referencing DefaultRackCSeq and DefaultRackInviteCSeq.
//   - SUBSCRIBE: Event: presence, Accept (PIDF) and Supported: eventlist.
//   - NOTIFY: Subscription-State and Event: presence.
//   - REFER: Refer-To (with a Replaces parameter built from the Call-ID),
//     Refer-Sub and Supported: replaces. req must already have a Call-ID.
//   - PUBLISH: Expires (DefaultExpiresHeader) and Event: presence.
//
// Bodies and their Content-Type (INVITE, MESSAGE, INFO, PUBLISH, NOTIFY) are
// handled in the function 'NewRequestBody'. Returns false, leaving req
// untouched, if req is nil or a REFER request has no Call-ID header.
func AddMethodHeaders(req *sip.Request, method sip.RequestMethod) bool {
	if req == nil {
		output.PrintfFrameworkError(errMsgRequiredParam, "req")

		return false
	}
	switch method {
	case sip.OPTIONS:
		req.AppendHeader(sip.NewHeader("Accept", "application/sdp"))
	case sip.REGISTER:
		req.AppendHeader(sip.NewHeader("Expires", strconv.Itoa(DefaultExpiresHeader)))
	case sip.PRACK:
		req.AppendHeader(sip.NewHeader("RAck", fmt.Sprintf("%d %d INVITE", DefaultRackCSeq, DefaultRackInviteCSeq)))
	case sip.SUBSCRIBE:
		req.AppendHeader(sip.NewHeader("Event", "presence"))
		req.AppendHeader(sip.NewHeader("Accept", ContentTypePidf))
		req.AppendHeader(sip.NewHeader("Supported", "eventlist"))
	case sip.NOTIFY:
		req.AppendHeader(sip.NewHeader("Subscription-State", "active;expires=3500"))
		req.AppendHeader(sip.NewHeader("Event", "presence"))
	case sip.REFER:
		// The Replaces parameter embeds the Call-ID, so it must exist first.
		if req.CallID() == nil {
			output.PrintFrameworkError("CallID header is required for REFER")

			return false
		}
		req.AppendHeader(&sip.ReferToHeader{
			Address: sip.Uri{
				Host: req.Recipient.Host,
				Port: req.Recipient.Port,
				User: "carol",
				UriParams: sip.HeaderParams{
					// URL-encoded ";to-tag=54321;from-tag=12345".
					"Replaces": req.CallID().Value() + `%3Bto-tag%3D54321%3Bfrom-tag%3D12345`,
				},
			},
		})
		req.AppendHeader(sip.NewHeader("Refer-Sub", "true"))
		req.AppendHeader(sip.NewHeader("Supported", "replaces"))
	case sip.PUBLISH:
		req.AppendHeader(sip.NewHeader("Expires", strconv.Itoa(DefaultExpiresHeader)))
		req.AppendHeader(sip.NewHeader("Event", "presence"))
	case sip.INVITE, sip.ACK, sip.CANCEL, sip.BYE, sip.INFO, sip.MESSAGE, sip.UPDATE:
		// No method-specific headers.
	}

	return true
}

// NewRequestBody returns the default body for method together with the value
// of its Content-Type header. INVITE, MESSAGE, INFO, PUBLISH and NOTIFY get a
// body; every other method (known or not) yields two empty strings.
func NewRequestBody(method sip.RequestMethod) (string, string) {
	var body, contentType string
	switch method {
	case sip.INVITE:
		body, contentType = NewDefaultInviteBody("", "", ""), DefaultInviteContentType
	case sip.MESSAGE:
		body, contentType = DefaultMessageBodyContent, DefaultMessageContentType
	case sip.INFO:
		body, contentType = DefaultInfoBodyContent, DefaultInfoContentType
	case sip.PUBLISH:
		body, contentType = NewDefaultPublishBody("", "", ""), ContentTypePidf
	case sip.NOTIFY:
		body, contentType = NewDefaultNotifyBody("", "", "", ""), ContentTypePidf
	case sip.ACK, sip.CANCEL, sip.BYE, sip.REGISTER, sip.OPTIONS, sip.SUBSCRIBE, sip.REFER, sip.PRACK, sip.UPDATE:
		// These methods are sent without a body.
	}

	return body, contentType
}

// NewDefaultInviteBody returns a minimal SDP offer (one PCMU audio stream on
// port 5004) for an INVITE request. Lines are separated by "\n" and the last
// line has no terminator.
//
// Empty arguments are replaced by defaults:
//   - host: a random IPv4 address (used in both the o= and c= lines).
//   - sessid: 6 random digits.
//   - sessver: 6 random digits.
func NewDefaultInviteBody(host, sessid, sessver string) string {
	if len(host) == 0 {
		host = random.RandIPv4().String()
	}
	if len(sessid) == 0 {
		sessid = random.RandDigits(6)
	}
	if len(sessver) == 0 {
		sessver = random.RandDigits(6)
	}

	const sdpTemplate = "v=0\n" +
		"o=caller %s %s IN IP4 %s\n" +
		"s=-\n" +
		"c=IN IP4 %s\n" +
		"t=0 0\n" +
		"m=audio 5004 RTP/AVP 0\n" +
		"a=rtpmap:0 PCMU/8000"

	return fmt.Sprintf(sdpTemplate, sessid, sessver, host, host)
}

// NewDefaultPublishBody returns a PIDF presence document for a PUBLISH request,
// with a single tuple carrying the given status.
//
// Empty arguments are replaced by defaults:
//   - id: 6 random letters.
//   - status: "open".
//   - entity: "sip:bob@<random IPv4>:5060".
func NewDefaultPublishBody(id, status, entity string) string {
	if len(id) == 0 {
		id = random.RandLetters(6)
	}
	if len(status) == 0 {
		status = "open"
	}
	if len(entity) == 0 {
		entity = "sip:bob@" + random.RandIPv4().String() + ":5060"
	}

	const pidfTemplate = `<?xml version="1.0" encoding="UTF-8"?>
<presence xmlns="urn:ietf:params:xml:ns:pidf"
					entity="%s">
	<tuple id="%s">
		<status>
			<basic>%s</basic>
		</status>
	</tuple>
</presence>`

	return fmt.Sprintf(pidfTemplate, entity, id, status)
}

// NewDefaultNotifyBody returns a PIDF presence document for a NOTIFY request,
// with a single tuple carrying status, contact and timestamp.
//
// Empty arguments are replaced by defaults:
//   - id: 6 random letters.
//   - status: "open".
//   - contact: "sip:bob@<random IPv4>:5060".
//   - ts: the current UTC time in RFC 3339 format.
func NewDefaultNotifyBody(id, status, contact, ts string) string {
	if len(id) == 0 {
		id = random.RandLetters(6)
	}
	if len(status) == 0 {
		status = "open"
	}
	if len(contact) == 0 {
		contact = "sip:bob@" + random.RandIPv4().String() + ":5060"
	}
	if len(ts) == 0 {
		now := time.Now().UTC()
		ts = now.Format(time.RFC3339)
	}

	const presenceTemplate = `<?xml version="1.0"?>
<presence xmlns="urn:ietf:params:xml:ns:pidf">
	<tuple id="%s">
		<status>
			<basic>%s</basic>
		</status>
		<contact>%s</contact>
		<timestamp>%s</timestamp>
	</tuple>
</presence>`

	return fmt.Sprintf(presenceTemplate, id, status, contact, ts)
}

// ParseTransport converts a transport name to its TransportType, ignoring case.
// It accepts "udp", "tcp", "tls", "ws" and "wss" (every concrete TransportType).
// Any other value yields UNKNOWN.
func ParseTransport(transport string) TransportType {
	switch strings.ToLower(transport) {
	case "udp":
		return UDP
	case "tcp":
		return TCP
	case "tls":
		return TLS
	case "ws":
		return WS
	case "wss":
		return WSS
	default:
		return UNKNOWN
	}
}

// methodsByLowerName maps the lower-cased name of every method ParseMethod
// recognizes to the corresponding sipgo constant.
var methodsByLowerName = map[string]sip.RequestMethod{
	"options":   sip.OPTIONS,
	"invite":    sip.INVITE,
	"register":  sip.REGISTER,
	"ack":       sip.ACK,
	"bye":       sip.BYE,
	"cancel":    sip.CANCEL,
	"subscribe": sip.SUBSCRIBE,
	"notify":    sip.NOTIFY,
	"publish":   sip.PUBLISH,
	"refer":     sip.REFER,
	"info":      sip.INFO,
	"message":   sip.MESSAGE,
	"prack":     sip.PRACK,
	"update":    sip.UPDATE,
}

// ParseMethod converts a method name to a 'sip.RequestMethod', ignoring case.
// Unrecognized names yield the empty method "".
func ParseMethod(method string) sip.RequestMethod {
	// A missing key returns the zero value, which is the empty method.
	return methodsByLowerName[strings.ToLower(method)]
}
