#!/bin/bash

# NOTE: This is a lower-level interface to the SSH Helper.
# Run `sm-ssh -h` for the high-level interface.

# Syntax:
# sm-local-start-ssh [--proxy-setup-only] <instance_id> [<extra_ssh_args>...]

# Strict mode: stop at the first failing command and make a pipeline fail
# when any of its stages fails.
set -eo pipefail

# Locate the directory this script lives in (realpath resolves symlinks) and
# load the shared helper functions from there. If that copy cannot be sourced,
# fall back to the bare name and let `source` resolve it on its own.
self=$(realpath "${BASH_SOURCE[0]}")
dir=$(dirname "$self")
if ! source "$dir"/sm-helper-functions 2>/dev/null; then
  source sm-helper-functions
fi

# Optional leading flag. When present it is consumed from the positional
# parameters and recorded as the string "true"; otherwise the flag stays
# "false" and the arguments are left untouched.
proxy_setup_only="false"
case "$1" in
  --proxy-setup-only)
    proxy_setup_only="true"
    shift
    ;;
esac

# Startup banner on stdout; skipped entirely in proxy-setup-only mode.
if [[ "$proxy_setup_only" == "false" ]]; then
  printf '%s\n' "sm-local-start-ssh: Starting in $dir"

  # Report which Python is selected via SM_SSH_PYTHON, or warn when it is unset.
  if [[ -n "${SM_SSH_PYTHON}" ]]; then
    printf '%s\n' "sm-local-start-ssh: Using Python from ${SM_SSH_PYTHON}"
  else
    printf '%s\n' "sm-local-start-ssh: WARNING: Using system Python. This can cause unexpected behavior. Please, use sm-ssh tool."
  fi
fi

# With SM_SSH_DEBUG=true, append the same startup information (plus the
# remaining command-line arguments) to the debug log, one timestamped line per
# message. A single grouped redirect appends all lines to the log file.
if [[ "$SM_SSH_DEBUG" == "true" ]]; then
  {
    echo "$(date -Iseconds) sm-local-start-ssh: Starting in $dir"
    echo "$(date -Iseconds) sm-local-start-ssh: Executed with args: $*"

    if [[ -n "${SM_SSH_PYTHON}" ]]; then
      echo "$(date -Iseconds) sm-local-start-ssh: Using Python from ${SM_SSH_PYTHON}"
    else
      echo "$(date -Iseconds) sm-local-start-ssh: WARNING: Using system Python. This can cause unexpected behavior. Please, use sm-ssh tool."
    fi
  } >>/tmp/sm-ssh-debug.log
fi

# The next positional argument carries the managed instance ID.
INSTANCE_ID=${1-}
# `shift` returns non-zero when there is nothing to shift, which would abort
# the script under `set -e` before the error message below is printed.
if (( $# > 0 )); then
  shift
fi

# Drop spaces and newlines, then keep only the trailing `mi-<hex>` token.
# `|| true`: grep exits with status 1 when nothing matches; combined with
# `set -e` and `pipefail` that would terminate the script silently right here,
# making the explicit check below unreachable. An empty result is handled there.
INSTANCE_ID=$(
  printf '%s' "$INSTANCE_ID" | tr -d ' \n' | grep -oe 'mi-[0-9a-f]*$' || true
)

if [[ -z "${INSTANCE_ID}" ]]; then
  # Diagnostics go to stderr so they are not mixed into stdout.
  echo "INSTANCE_ID is not provided or incorrect (should be in the form mi-1234567890abcdef0)" >&2
  exit 1
fi

# Report the parsed ID: on stdout in normal mode, and in the debug log when
# SM_SSH_DEBUG=true.
if [[ "$proxy_setup_only" == "false" ]]; then
  echo "sm-local-start-ssh: Got instance ID: '$INSTANCE_ID'"
fi
if [[ "$SM_SSH_DEBUG" == "true" ]]; then
  echo "$(date -Iseconds) sm-local-start-ssh: Got instance ID: '$INSTANCE_ID'" >>/tmp/sm-ssh-debug.log
fi

# All remaining arguments are forwarded to the proxy script. Useful for port
# forwarding and debugging, e.g. passing the -vvv option. They are kept as one
# flat, space-joined string.
EXTRA_SSH_ARGS="$*"

if [[ "$proxy_setup_only" == "false" ]]; then
  printf '%s\n' "sm-local-start-ssh: Fetching bucket location"
fi

# Unless the caller exported SSH_AUTHORIZED_KEYS_PATH, derive it from the
# default bucket of the SageMaker session. The interpreter command is whatever
# the `_python` helper prints; the two logger tweaks keep the SDK's and
# botocore's messages below WARNING from reaching the output.
if [[ -z "${SSH_AUTHORIZED_KEYS_PATH}" ]]; then
  # shellcheck disable=SC2091
  bucket=$($(_python) <<'EOF'
import logging
logging.getLogger('sagemaker.config').setLevel(logging.WARNING)
logging.getLogger('botocore.credentials').setLevel(logging.WARNING)
import sagemaker
print(sagemaker.Session().default_bucket())
EOF
  )
  SSH_AUTHORIZED_KEYS_PATH="s3://${bucket}/ssh-authorized-keys/"
fi

# Show the effective S3 path and how to override it (normal mode only).
if [[ "$proxy_setup_only" == "false" ]]; then
  printf '%s\n' \
    "SSH authorized keys S3 path -> ${SSH_AUTHORIZED_KEYS_PATH}" \
    "NOTE: to override the default S3 path, run 'export SSH_AUTHORIZED_KEYS_PATH=s3://DOC-EXAMPLE-BUCKET/ssh-public-keys-jane-doe/' without quotes before attempting to connect."
fi


# Normal mode: hand over to sm-connect-ssh-proxy with the instance ID, the
# authorized-keys path and the extra arguments.
# Proxy-setup-only mode: run sm-connect-ssh-proxy with --silent-setup-only
# (its output goes to the debug log when SM_SSH_DEBUG=true), then open the
# AWS-StartSSHSession SSM session to port 22 of the instance.
if [[ "$proxy_setup_only" == "false" ]]; then
  # shellcheck disable=SC2086  # extra args have to be unquoted to be parsed from the inner script
  sm-connect-ssh-proxy "${INSTANCE_ID}" \
      "${SSH_AUTHORIZED_KEYS_PATH}" \
      $EXTRA_SSH_ARGS
else
  if [[ "$SM_SSH_DEBUG" == "true" ]]; then
    # shellcheck disable=SC2086
    echo "$(date -Iseconds) sm-local-start-ssh: Setting up proxy with args: $EXTRA_SSH_ARGS" >>/tmp/sm-ssh-debug.log
    sm-connect-ssh-proxy --silent-setup-only "${INSTANCE_ID}" \
        "${SSH_AUTHORIZED_KEYS_PATH}" \
        $EXTRA_SSH_ARGS >>/tmp/sm-ssh-debug.log 2>&1
  else
    # shellcheck disable=SC2086
    sm-connect-ssh-proxy --silent-setup-only "${INSTANCE_ID}" \
        "${SSH_AUTHORIZED_KEYS_PATH}" \
        $EXTRA_SSH_ARGS
  fi

  # Take the region from the row of `aws configure list` whose first column is
  # exactly "region". Matching the column (instead of grepping for the word)
  # avoids picking up other rows that merely contain "region", e.g. a profile
  # named "my-region", and avoids a silent `pipefail` abort when nothing matches.
  # NOTE(review): newer AWS CLI v2 releases appear to print the table as
  # "name : value : type : location"; the ":" check tolerates that layout —
  # confirm against the CLI versions in use.
  CURRENT_REGION=$(
    aws configure list \
      | awk '$1 == "region" && !seen { v = ($2 == ":") ? $3 : $2; print v; seen = 1 }'
  )

  # An unset region shows up as "<not set>" (first word "<not"); passing that
  # on as --region can only fail later with a less helpful error.
  if [[ -z "$CURRENT_REGION" || "$CURRENT_REGION" == "<"* ]]; then
    if [[ "$SM_SSH_DEBUG" == "true" ]]; then
      echo "$(date -Iseconds) sm-local-start-ssh: ERROR: AWS region is not configured" >>/tmp/sm-ssh-debug.log
    fi
    echo "sm-local-start-ssh: ERROR: Could not determine the AWS region from 'aws configure list'. Configure a default region, e.g. run 'aws configure' or export AWS_DEFAULT_REGION." >&2
    exit 1
  fi

  if [[ "$SM_SSH_DEBUG" == "true" ]]; then
    echo "$(date -Iseconds) sm-local-start-ssh: Starting SSM session in ${CURRENT_REGION}" >>/tmp/sm-ssh-debug.log
  fi
  aws ssm start-session \
    --reason 'Local user started SSH with SageMaker SSH Helper proxy' \
    --region "${CURRENT_REGION}" \
    --target "${INSTANCE_ID}" \
    --document-name AWS-StartSSHSession \
    --parameters portNumber=22
fi

# Closing marker in the debug log. Written as a full `if` statement (not a
# `&&` chain) so the script still ends with status 0 when debugging is off.
if [[ "${SM_SSH_DEBUG}" == "true" ]]; then
  printf '%s\n' "$(date -Iseconds) sm-local-start-ssh: Done." >>/tmp/sm-ssh-debug.log
fi