# Tag of the pytorch/pytorch base image. The default stays "latest" so existing
# builds behave as before; pass --build-arg PYTORCH_TAG=<tag> to build from a
# fixed release and get the same base image on every build.
ARG PYTORCH_TAG=latest
FROM pytorch/pytorch:${PYTORCH_TAG}

# The package installation steps that follow need root privileges.
USER root

# Install the command-line tools the build needs, all in one layer:
#   ca-certificates - TLS trust store for the HTTPS download and git clone
#   curl            - fetches the SAM checkpoint below
#   git             - required by the pip "git+https://" install further down
# "update" and "install" share a single RUN so a cached package index is never
# paired with a changed package list, and the index is removed in the same
# layer so it does not add to the image size.
RUN apt-get update \
    && apt-get install -y --no-install-recommends \
        ca-certificates \
        curl \
        git \
    && rm -rf /var/lib/apt/lists/*

# Download the SAM ViT-H checkpoint (once) into model-weights/.
# --fail aborts the build on an HTTP error instead of saving the error page as
# the checkpoint; --location follows redirects; --retry covers transient
# network failures.
# NOTE(review): the download is not checksum-verified - consider adding a
# sha256sum check once the expected digest is confirmed.
RUN mkdir -p model-weights/ \
    && curl --fail --location --retry 3 \
        --output model-weights/sam_vit_h_4b8939.pth \
        https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

# Python dependencies for exporting the ONNX decoder and packaging the models.
# segment-anything is pinned with "@<rev>": pip reads the revision after "@",
# whereas a "#<rev>" URL fragment is not treated as a revision and would
# install whatever the default branch currently points at.
# --no-cache-dir keeps pip's download cache out of the image layer.
# NOTE(review): torch-model-archiver and awscli are unpinned - consider pinning.
RUN pip install --no-cache-dir \
        "onnxruntime==1.14.1" \
        "onnx==1.13.1" \
        "git+https://github.com/facebookresearch/segment-anything.git@567662b" \
        torch-model-archiver \
        awscli

# Bring the handler scripts (every file matching handler*.py), the src/ tree,
# and requirements.txt from the build context into the working directory.
COPY handler*.py ./
COPY src ./src/
COPY requirements.txt ./

# Pack the contents of src/ into a gzip-compressed tarball. Paths inside the
# archive are relative to src/ because tar switches into that directory first.
RUN tar --create --gzip --file=sam_serve.tar.gz --directory=src .

# Package the SAM ViT-H checkpoint as the "sam_vit_h_encode" model archive,
# using handler_encode.py as the handler and shipping the src/ tarball and
# requirements.txt alongside it.
RUN torch-model-archiver \
        --model-name sam_vit_h_encode \
        --version 1.0.0 \
        --serialized-file model-weights/sam_vit_h_4b8939.pth \
        --handler handler_encode.py \
        --extra-files sam_serve.tar.gz \
        --requirements-file requirements.txt

# Helper scripts from the build context; export_onnx_model.py is run below.
COPY scripts/ ./scripts/

# Run the export script against the ViT-H checkpoint; its output path is
# sam_vit_h_decode.onnx in the working directory.
RUN python scripts/export_onnx_model.py \
        --checkpoint model-weights/sam_vit_h_4b8939.pth \
        --model-type vit_h \
        --output sam_vit_h_decode.onnx

# Package the exported ONNX file as the "sam_vit_h_decode" model archive,
# using handler_decode.py as the handler and shipping the src/ tarball and
# requirements.txt alongside it.
RUN torch-model-archiver \
        --model-name sam_vit_h_decode \
        --version 1.0.0 \
        --serialized-file sam_vit_h_decode.onnx \
        --handler handler_decode.py \
        --extra-files sam_serve.tar.gz \
        --requirements-file requirements.txt

# Collect the generated .mar archives in /home/model-store and delete the
# intermediate ONNX export and the downloaded checkpoint directory.
# Globs are prefixed with "./" so a file name beginning with "-" can never be
# parsed as a command-line option by mv or rm.
# NOTE: files created by earlier instructions remain stored in those earlier
# layers, so this removal tidies the final filesystem but does not reduce the
# total image size.
RUN mkdir -p /home/model-store \
    && mv ./*.mar /home/model-store/ \
    && rm ./*.onnx \
    && rm -rf model-weights

# The .mar files are copied to S3 by the GitHub Action.
