#!/bin/bash

# NOTE: This is a lower-level interface to the SSH Helper.
# Run `sm-ssh -h` for the high-level interface.

# Commands:
# connect-endpoint <endpoint_name> [<extra_ssh_args>]
# proxy-host <fqdn>
# stop-waiting
# run-command <command> <args...>

# Strict mode: abort on the first failing command, and treat a pipeline as
# failed when any of its stages fails.
set -eo pipefail

# Absolute path of this script and the directory that contains it.
self="$(realpath "${BASH_SOURCE[0]}")"
dir="$(dirname "$self")"

# Load the shared helper functions. The copy sitting next to this script is
# tried first (with its stderr silenced); if that fails, fall back to the
# bare name and let `source` resolve it on its own.
if ! source "$dir"/sm-helper-functions 2>/dev/null; then
  source sm-helper-functions
fi

echo "sm-local-ssh-inference: Starting in $dir"


# The first positional argument names the sub-command to run.
COMMAND="$1"

# Dispatch on the sub-command held in $COMMAND (the first positional argument).
#
#   proxy-host <fqdn>
#       Validates the host name for the "inference" resource type, exports the
#       SSH key environment variable, obtains the instance ID and runs
#       `sm-local-start-ssh --proxy-setup-only` for it.
#   connect-endpoint <endpoint_name> [<extra_ssh_args>]
#       Looks up the first SSM instance ID of the endpoint (waiting up to
#       300 seconds) and starts `sm-local-start-ssh` with the default -R/-L
#       arguments followed by any extra arguments, passed through verbatim.
#   stop-waiting
#       Shortcut for `run-command sm-wait stop`.
#   run-command <command> <args...>
#       Runs the given command through ssh as root@localhost on port 11022.
#
# Any other value prints an error to stderr and exits with status 1.
case "$COMMAND" in
  proxy-host)
    SM_SSH_HOST_NAME="$2"
    SM_RESOURCE_TYPE="inference"

    _check_ssh_proxy_host_name "$SM_SSH_HOST_NAME" "$SM_RESOURCE_TYPE"
    _export_ssh_key_env_var "$SM_SSH_HOST_NAME"

    INSTANCE_ID=$(_generate_key_and_print_instance_id "$SM_SSH_HOST_NAME")

    sm-local-start-ssh --proxy-setup-only "${INSTANCE_ID}"
    ;;

  connect-endpoint)
    ENDPOINT_NAME=$2

    # Fail fast instead of spending the lookup timeout on an empty name.
    if [[ -z "$ENDPOINT_NAME" ]]; then
      echo "ERROR: connect-endpoint requires an endpoint name" >&2
      exit 1
    fi

    # The endpoint name is handed to Python as an argument (sys.argv[1]) and
    # the here-doc delimiter is quoted, so the name is never interpolated into
    # the Python source and cannot break or alter the program.
    # shellcheck disable=SC2091
    INSTANCE_ID=$($(_python) - "$ENDPOINT_NAME" <<'EOF'
import sys
import sagemaker; from sagemaker_ssh_helper.log import SSHLog;
import logging; logging.basicConfig(level=logging.INFO);
print(SSHLog().get_endpoint_ssm_instance_ids(sys.argv[1], timeout_in_sec=300)[0])
EOF
    )

    # Drop the sub-command and the endpoint name; what remains are the
    # caller's extra arguments.
    shift 2

    # "$@" keeps every extra argument intact (embedded spaces, glob
    # characters) instead of re-splitting it.
    sm-local-start-ssh "$INSTANCE_ID" \
        -R localhost:12345:localhost:12345 \
        -L localhost:12022:localhost:22 \
        "$@"
    ;;

  stop-waiting)
    # Re-invoke this script; quoted so a path containing spaces still works.
    "$0" run-command sm-wait stop
    ;;

  run-command)
    # Drop the sub-command; the remaining words form the remote command.
    shift

    # Host key checking is disabled and no known_hosts file is used for this
    # connection to localhost:11022. "$@" prevents local word-splitting and
    # glob expansion of the command's arguments; ssh itself joins them with
    # spaces to form the remote command line.
    ssh -4 -i ~/.ssh/sagemaker-ssh-gw -p 11022 root@localhost \
      -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \
      "$@"
    ;;

  *)
    echo "ERROR: Unknown command: '$COMMAND'" >&2
    exit 1
    ;;
esac