#!python

# to package:
# python -m build --wheel

# to test it locally
# python -m venv .
# source ./bin/activate
# pip install -I dist/csrnet-0.2-py3-none-any.whl
# csrnet pathtoimage.jpg

# to upload
# python -m twine check dist/*
# python -m twine upload dist/*


import os
import sys
import argparse
import h5py
import torch
import numpy as np
import PIL.Image as Image
import torch.nn as nn
from torchvision import transforms

def load_net(fname, net):
    """Copy parameters stored in an HDF5 file into ``net`` in place.

    For every entry name in ``net.state_dict()``, the dataset with the same
    name is read from the file and copied into the corresponding tensor.

    Args:
        fname: path of the HDF5 file, opened read-only.
        net: object exposing ``state_dict()`` (e.g. an ``nn.Module``).
    """
    with h5py.File(fname, 'r') as weights_file:
        state = net.state_dict()
        for name in state:
            stored = np.asarray(weights_file[name])
            # copy_ writes into the existing tensor instead of rebinding it.
            state[name].copy_(torch.from_numpy(stored))

def make_layers(cfg, in_channels=3, batch_norm=False, dilation=False):
    """Build a sequential convolution stack from a configuration list.

    Each entry of ``cfg`` is either the string ``'M'`` (a 2x2 max-pool with
    stride 2) or an output-channel count for a 3x3 convolution followed by
    an in-place ReLU, optionally with a ``BatchNorm2d`` in between.

    Args:
        cfg: iterable of ``'M'`` markers and output-channel counts.
        in_channels: channel count fed to the first convolution.
        batch_norm: insert ``nn.BatchNorm2d`` after every convolution.
        dilation: use dilation rate 2 (with padding 2) instead of 1.

    Returns:
        An ``nn.Sequential`` holding the layers in configuration order.
    """
    # Padding equals the dilation rate for every 3x3 convolution.
    rate = 2 if dilation else 1
    modules = []
    channels = in_channels
    for entry in cfg:
        if entry == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        modules.append(
            nn.Conv2d(channels, entry, kernel_size=3, padding=rate, dilation=rate)
        )
        if batch_norm:
            modules.append(nn.BatchNorm2d(entry))
        modules.append(nn.ReLU(inplace=True))
        # The next convolution consumes what this one produced.
        channels = entry
    return nn.Sequential(*modules)

class CSRNet(nn.Module):
    """Convolutional network mapping an image batch to a single-channel map.

    The model chains three parts:

    * ``frontend``: ten 3x3 convolutions with three 2x2 max-pools (the
      layout the original comment associates with VGG16),
    * ``backend``: six 3x3 convolutions with dilation rate 2,
    * ``output_layer``: a 1x1 convolution reducing 64 channels to 1.

    NOTE: the constructor neither loads pretrained weights nor calls
    ``_initialize_weights``; parameters keep PyTorch's default
    initialisation until a checkpoint is loaded by the caller.
    """

    def __init__(self):
        super().__init__()
        # Counter kept on the instance; not used anywhere in this class.
        self.seen = 0
        self.frontend_feat = (
            [64] * 2 + ['M'] + [128] * 2 + ['M'] + [256] * 3 + ['M'] + [512] * 3
        )
        self.backend_feat = [512] * 3 + [256, 128, 64]
        # Attribute names and creation order determine the state_dict keys,
        # so they must stay in sync with stored checkpoints.
        self.frontend = make_layers(self.frontend_feat)
        self.backend = make_layers(
            self.backend_feat, in_channels=512, dilation=True
        )
        self.output_layer = nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, x):
        """Run ``x`` through frontend, backend and the 1x1 output layer."""
        features = self.backend(self.frontend(x))
        return self.output_layer(features)

    def _initialize_weights(self):
        """Re-initialise conv and batch-norm parameters in place.

        Convolution weights are drawn from a normal distribution with
        ``std=0.01`` and their biases zeroed; batch-norm weights are set to
        1 and biases to 0.
        """
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if not isinstance(module, nn.Conv2d):
                continue
            nn.init.normal_(module.weight, std=0.01)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)



def load_data(img_path):
    """Open the image at ``img_path`` and return it converted to RGB mode."""
    return Image.open(img_path).convert('RGB')


def main(argv=None):
    """Command-line entry point: print the people count estimated for an image.

    Builds a ``CSRNet``, loads its weights from ``models/weights.zip.npz``,
    runs the normalised image through the model and prints the sum of the
    output map truncated to an integer.

    Args:
        argv: optional list of command-line arguments; ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Returns:
        None, so ``sys.exit(main())`` exits with status 0 on success.
    """
    parser = argparse.ArgumentParser(description='Count the number of people on the image')
    parser.add_argument('filepath')
    args = parser.parse_args(argv)

    model = CSRNet()
    # The archive was exported from the training checkpoint's 'state_dict'
    # with np.savez_compressed, so its keys are the model's state_dict keys.
    # NOTE: the path is relative to the current working directory.
    # The context manager closes the archive once every array has been read.
    with np.load("models/weights.zip.npz") as archive:
        state = {layer: torch.from_numpy(mat) for layer, mat in archive.items()}
    model.load_state_dict(state)
    # Inference only: switch to eval mode before running the model.
    model.eval()

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    # Close the image file as soon as the pixel data has been converted.
    with Image.open(args.filepath) as raw:
        img = transform(raw.convert('RGB'))

    # No gradients are needed, so skip building the autograd graph.
    with torch.no_grad():
        output = model(img.unsqueeze(0))
    print(int(output.sum().item()))


# When executed as a script, run the CLI and use main()'s return value as
# the process exit status.
if __name__ == "__main__":
    status = main()
    sys.exit(status)