#!/bin/bash

# Helper script to connect to a remote managed instance with SSM and start SSH port forwarding.
# Unless the SSH_KEY environment variable is set, it generates a new SSH key at ~/.ssh/sagemaker-ssh-gw
# on every run. The public part is uploaded to S3 and copied to the instance by executing a remote SSM command.

# Syntax:
# sm-connect-ssh-proxy [--silent-setup-only] <instance_id> <s3_ssh_authorized_keys_path> <extra_ssh_args>

# Abort on the first failing command and treat a pipeline as failed when any of its stages fails
set -e
set -o pipefail

# Resolve the real location of this script so the helper functions can be loaded from the same directory
self=$(realpath "${BASH_SOURCE[0]}")
dir=$(dirname "$self")
# Load the helpers that live next to this script; if that fails (errors are hidden), fall back to a
# bare-name `source`, which bash resolves through PATH and then the current directory
source "$dir"/sm-helper-functions 2>/dev/null || source sm-helper-functions

# Parse the command line:
#   [--silent-setup-only]           optional first argument; suppresses progress messages and makes the
#                                   script exit right after the public key has been installed
#   <instance_id>                   SSM managed instance to connect to
#   <s3_ssh_authorized_keys_path>   S3 location used to hand the public key over to the instance
#   <extra_ssh_args>                everything else, passed on to ssh
silent_setup_only="false"
if [[ "$1" == "--silent-setup-only" ]]; then
  shift
  silent_setup_only="true"
fi

# Both positional arguments are mandatory. Without this check the shifts below fail under `set -e`
# and the script dies with status 1 without printing anything; the status is kept, the message is new.
if (( $# < 2 )); then
  echo "$(date -Iseconds) sm-connect-ssh-proxy: Error: missing required arguments." >&2
  echo "Usage: sm-connect-ssh-proxy [--silent-setup-only] <instance_id> <s3_ssh_authorized_keys_path> <extra_ssh_args>" >&2
  exit 1
fi

INSTANCE_ID="$1"
SSH_AUTHORIZED_KEYS_PATH="$2"
shift 2
# Kept as one flat string on purpose: it is word-split again when it is handed to ssh
PORT_FWD_ARGS=$*

# Make sure the S3 prefix ends with a slash, because the key file name is appended to it verbatim
if [[ "$SSH_AUTHORIZED_KEYS_PATH" != */ ]]; then
  SSH_AUTHORIZED_KEYS_PATH="${SSH_AUTHORIZED_KEYS_PATH}/"
fi

# Announce the target and the extra ssh arguments (interactive mode only)
if [[ "$silent_setup_only" == "false" ]]; then
  printf '%s\n' \
    "$(date -Iseconds) sm-connect-ssh-proxy: Connecting to: $INSTANCE_ID" \
    "$(date -Iseconds) sm-connect-ssh-proxy: Extra args: $PORT_FWD_ARGS"
fi

# Ask SSM for the ping status of the instance
instance_status=$(
  aws ssm describe-instance-information \
    --filters Key=InstanceIds,Values="$INSTANCE_ID" \
    --query 'InstanceInformationList[0].PingStatus' \
    --output text
)

if [[ "$silent_setup_only" == "false" ]]; then
  printf '%s\n' "$(date -Iseconds) sm-connect-ssh-proxy: Instance status: $instance_status"
fi

# Any status other than Online stops the script here; this message is printed in silent mode too
case "$instance_status" in
  Online) ;;
  *)
    printf '%s\n' "$(date -Iseconds) sm-connect-ssh-proxy: Error: Instance is offline."
    exit 1
    ;;
esac

# When the caller did not provide SSH_KEY, (re)create the default key pair without a passphrase.
# The "yes" on stdin answers the overwrite prompt ssh-keygen shows when the key file already exists.
if [[ -z "$SSH_KEY" ]]; then
  SSH_KEY=~/.ssh/sagemaker-ssh-gw
  if [[ "$silent_setup_only" == "false" ]]; then
    printf '%s\n' "$(date -Iseconds) sm-connect-ssh-proxy: Generating $SSH_KEY keypair with ECDSA and uploading public key to $SSH_AUTHORIZED_KEYS_PATH"
  fi
  printf 'yes\n' | ssh-keygen -t ecdsa -q -f "$SSH_KEY" -N '' >/dev/null
fi

# The public key is stored in S3 under the same file name it has locally
SSH_KEY_NAME=$(basename "$SSH_KEY")
SSH_KEY_S3_PATH="${SSH_AUTHORIZED_KEYS_PATH}${SSH_KEY_NAME}"

# Upload the public key; in silent mode the progress output of the AWS CLI is discarded
if [[ "$silent_setup_only" != "false" ]]; then
  aws s3 cp "${SSH_KEY}.pub" "${SSH_KEY_S3_PATH}.pub" >/dev/null
else
  aws s3 cp "${SSH_KEY}.pub" "${SSH_KEY_S3_PATH}.pub"
fi

# Reads `aws configure list` output on stdin and prints the value column of the region row.
# Only the row whose first column is exactly "region" is used, so rows that merely contain the word
# (for example a profile named "my-region-dev") can no longer leak into the result.
_sm_parse_configured_region() {
  awk '$1 == "region" { print $2 }'
}

# Region the AWS CLI is effectively configured with; used for the SSM calls that follow
CURRENT_REGION=$(aws configure list | _sm_parse_configured_region)
if [[ "$silent_setup_only" == "false" ]]; then
  echo "$(date -Iseconds) sm-connect-ssh-proxy: Will use AWS Region: $CURRENT_REGION"
fi

# The version is only reported to the user, it is not checked
AWS_CLI_VERSION=$(aws --version)
if [[ "$silent_setup_only" == "false" ]]; then
  echo "$(date -Iseconds) sm-connect-ssh-proxy: AWS CLI version (should be v2): $AWS_CLI_VERSION"
fi

# TODO: consider moving to start-session with AWS-StartNonInteractiveCommand

if [[ "$silent_setup_only" == "false" ]]; then
  echo "$(date -Iseconds) sm-connect-ssh-proxy: Running SSM commands at region ${CURRENT_REGION} to copy public key to ${INSTANCE_ID}"
fi
# With SM_SSH_DEBUG=true the same progress line is also appended to a debug log, in silent mode too
if [[ "$SM_SSH_DEBUG" == "true" ]]; then
  echo "$(date -Iseconds) sm-connect-ssh-proxy: Running SSM commands at region ${CURRENT_REGION} to copy public key to ${INSTANCE_ID}" >>/tmp/sm-ssh-debug.log
fi
# Run a shell script on the instance through SSM: it downloads the public key from S3 into
# /etc/ssh/authorized_keys.d/ and rebuilds /etc/ssh/authorized_keys from every file in that directory.
# The `ls` calls only add diagnostics to the command output. The JSON response is captured for parsing.
send_command=$(aws ssm send-command \
    --region "${CURRENT_REGION}" \
    --instance-ids "${INSTANCE_ID}" \
    --document-name "AWS-RunShellScript" \
    --comment "Copy public key for SSH helper" \
    --timeout-seconds 30 \
    --parameters "commands=[
        'mkdir -p /etc/ssh/authorized_keys.d/',
        'aws s3 cp \"${SSH_KEY_S3_PATH}.pub\" /etc/ssh/authorized_keys.d/',
        'ls -la /etc/ssh/authorized_keys.d/',
        'cat /etc/ssh/authorized_keys.d/* > /etc/ssh/authorized_keys',
        'ls -la /etc/ssh/authorized_keys'
      ]" \
    --no-cli-pager --no-paginate \
    --output json)

# sed expression that reduces a pretty-printed JSON line of the form `"key": "VALUE"` to VALUE.
# It is reused when the command invocation result is parsed further below.
json_value_regexp='s/^[^"]*".*": \"\(.*\)\"[^"]*/\1/'

# Pretty-print the response with json.tool, then extract the ID from the line that mentions CommandId
# NOTE(review): _python is not defined in this file - presumably provided by sm-helper-functions
send_command=$(echo "$send_command" | $(_python) -m json.tool)
command_id=$(echo "$send_command" | grep "CommandId" | sed -e "$json_value_regexp")
if [[ "$silent_setup_only" == "false" ]]; then
  echo "$(date -Iseconds) sm-connect-ssh-proxy: Got command ID: $command_id"
fi

# Wait a little bit to prevent strange InvocationDoesNotExist error
sleep 5

# Switch to unicode for AWS CLI to properly parse output (set once, it does not change between attempts)
export LC_CTYPE=en_US.UTF-8

# Poll the command invocation, up to 15 attempts one second apart, until it leaves the
# Pending / InProgress states. The last observed state stays in command_status for the check that
# follows the loop; it remains empty when no attempt could read the invocation.
command_status=""
output_content=""
error_content=""
for i in {1..15}; do
  # Right after send-command the invocation may still be unknown to SSM, which makes the CLI fail.
  # Running the call as an `if` condition keeps `set -e` from aborting the whole script, so a
  # failed attempt is simply retried on the next iteration instead of killing the connection setup.
  if ! command_output=$(aws ssm get-command-invocation \
      --instance-id "${INSTANCE_ID}" \
      --command-id "${command_id}" \
      --no-cli-pager --no-paginate \
      --output json); then
    sleep 1
    continue
  fi
  # Pretty-print so that every key is on its own line, then pick the fields with grep + sed
  command_output=$(echo "$command_output" | $(_python) -m json.tool)
  command_status=$(echo "$command_output" | grep '"Status":' | sed -e "$json_value_regexp")
  output_content=$(echo "$command_output" | grep '"StandardOutputContent":' | sed -e "$json_value_regexp")
  error_content=$(echo "$command_output" | grep '"StandardErrorContent":' | sed -e "$json_value_regexp")

  if [[ "$silent_setup_only" == "false" ]]; then
    echo "$(date -Iseconds) sm-connect-ssh-proxy: Command status: $command_status ($i)"
  fi
  # Any state other than Pending / InProgress ends the polling; report the output and stop
  if [[ "$command_status" != "Pending" && "$command_status" != "InProgress" ]]; then
    if [[ "$silent_setup_only" == "false" ]]; then
      echo "$(date -Iseconds) sm-connect-ssh-proxy: Command output: $output_content"
    fi
    if [[ "$error_content" != "" ]]; then
      if [[ "$silent_setup_only" == "false" ]]; then
        echo "$(date -Iseconds) sm-connect-ssh-proxy: Command error: $error_content"
      fi
    fi
    break
  fi
  sleep 1
done

# Anything other than Success, including a command that is still running, is fatal.
# These two messages are printed in silent mode as well.
[[ "$command_status" == "Success" ]] || {
  echo "$(date -Iseconds) sm-connect-ssh-proxy: ERROR: Command didn't finish successfully in time. Check SSM logs and Command history for more details."
  echo "$(date -Iseconds) sm-connect-ssh-proxy: Command status: $command_status. Command ID: $command_id. Region: ${CURRENT_REGION}"
  exit 2
}

# Optional trace in the debug log once the key has been installed
if [[ "$SM_SSH_DEBUG" == "true" ]]; then
  echo "$(date -Iseconds) sm-connect-ssh-proxy: Silent setup complete." >>/tmp/sm-ssh-debug.log
fi

# In setup-only mode the job is done here; otherwise carry on to the interactive session
[[ "$silent_setup_only" != "true" ]] || exit 0

echo "$(date -Iseconds) sm-connect-ssh-proxy: Starting SSH over SSM proxy (interactive session)"

# We don't use AWS-StartPortForwardingSession feature of SSM here, because we need port forwarding in both directions
#  with -L and -R parameters of SSH. This is useful for forwarding the PyCharm license server, which needs -R option.
#  SSM allows only forwarding of ports from the server (equivalent to the -L option).
# Instead, ssh itself is tunnelled through an SSM session that is started as its ProxyCommand.
# %p is an ssh ProxyCommand token that ssh replaces with the port it connects to.
proxy_command="aws ssm start-session\
 --reason 'Local user started SageMaker SSH Helper'\
 --region '${CURRENT_REGION}'\
 --target '${INSTANCE_ID}'\
 --document-name AWS-StartSSHSession\
 --parameters portNumber=%p"

# Log in as root with the key selected above only (no other identities, no password prompt).
# Host key checking is switched off and no known_hosts file is kept.
# $PORT_FWD_ARGS is left unquoted on purpose so that it is split into separate ssh arguments.
# shellcheck disable=SC2086
ssh -4 -o User=root -o IdentityFile="${SSH_KEY}" -o IdentitiesOnly=yes \
  -o ProxyCommand="$proxy_command" \
  -o ConnectTimeout=120 \
  -o ServerAliveInterval=15 -o ServerAliveCountMax=8 \
  -o PasswordAuthentication=no \
  -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \
  $PORT_FWD_ARGS "$INSTANCE_ID"
