#!python

import os
import sys
import argparse
import h5py
import torch
import numpy as np
import PIL.Image as Image
import torch.nn as nn
from torchvision import transforms

def load_net(fname, net):
    """Overwrite the tensors of ``net`` in place with values from an HDF5 file.

    Every key of ``net.state_dict()`` is looked up in the file under the same
    name, so the file is expected to hold one dataset per state-dict entry.

    Args:
        fname: Path of the HDF5 file, opened read-only.
        net: Module whose state-dict tensors receive the stored values.
    """
    with h5py.File(fname, 'r') as weights_file:
        for name, tensor in net.state_dict().items():
            stored = np.asarray(weights_file[name])
            tensor.copy_(torch.from_numpy(stored))

def make_layers(cfg, in_channels=3, batch_norm=False, dilation=False):
    """Build a sequential stack of 3x3 conv blocks and 2x2 max-pool layers.

    Args:
        cfg: Iterable describing the stack. The string ``'M'`` inserts a
            ``MaxPool2d(kernel_size=2, stride=2)``; any other entry is used as
            the output channel count of a 3x3 convolution followed by ReLU.
        in_channels: Channel count fed to the first convolution.
        batch_norm: When true, a ``BatchNorm2d`` is placed between each
            convolution and its ReLU.
        dilation: When true, convolutions use dilation 2 (and padding 2);
            otherwise dilation 1 (and padding 1).

    Returns:
        An ``nn.Sequential`` holding the layers in ``cfg`` order.
    """
    # Padding equals the dilation rate for every 3x3 convolution built here.
    rate = 2 if dilation else 1
    channels = in_channels
    modules = []
    for entry in cfg:
        if entry == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        modules.append(
            nn.Conv2d(channels, entry, kernel_size=3, padding=rate, dilation=rate)
        )
        if batch_norm:
            modules.append(nn.BatchNorm2d(entry))
        modules.append(nn.ReLU(inplace=True))
        # The next convolution consumes what this one produced.
        channels = entry
    return nn.Sequential(*modules)

class CSRNet(nn.Module):
    """Convolutional network that maps an image batch to a one-channel map.

    The model chains three parts:

    * ``frontend`` - 3x3 conv/ReLU blocks with three 2x2 max-pool stages,
      ending at 512 channels.
    * ``backend`` - dilated 3x3 conv/ReLU blocks reducing 512 channels to 64.
    * ``output_layer`` - a 1x1 convolution producing a single channel.
    """

    def __init__(self):
        super().__init__()
        # Counter initialised to zero; it is not updated anywhere in this class.
        self.seen = 0
        self.frontend_feat = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512]
        self.backend_feat = [512, 512, 512, 256, 128, 64]
        # Submodules are registered in this order; their attribute names form
        # the state-dict key prefixes.
        self.frontend = make_layers(self.frontend_feat)
        self.backend = make_layers(self.backend_feat, in_channels=512, dilation=True)
        self.output_layer = nn.Conv2d(64, 1, kernel_size=1)
        # NOTE: no pretrained weights are loaded here and _initialize_weights()
        # is not invoked; layers keep their default initialisation until the
        # caller loads a state dict.

    def forward(self, x):
        """Run ``x`` through frontend, backend and the 1x1 output convolution."""
        features = self.backend(self.frontend(x))
        return self.output_layer(features)

    def _initialize_weights(self):
        """Re-initialise conv layers (normal, std 0.01) and batch-norm layers.

        Conv biases are zeroed when present; batch-norm weight is set to 1 and
        bias to 0.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.normal_(module.weight, std=0.01)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)



def load_data(img_path):
    """Open the image at ``img_path`` and return it converted to RGB mode."""
    return Image.open(img_path).convert('RGB')


def main():
    """Command-line entry point: print the predicted people count for an image.

    Parses the image path (and an optional ``--model`` checkpoint path) from
    the command line, restores a ``CSRNet`` from the checkpoint on CPU, runs a
    single forward pass and prints the sum of the output map truncated to an
    integer.

    The checkpoint is expected to be a mapping with a ``'state_dict'`` entry
    compatible with ``CSRNet``.
    """
    parser = argparse.ArgumentParser(description='Count the number of people on the image')
    parser.add_argument('filepath', help='path of the image to analyse')
    parser.add_argument(
        '--model',
        default='models/0model_best.pth.tar',
        help='path of the checkpoint file (default: %(default)s)',
    )
    args = parser.parse_args()

    model = CSRNet()
    checkpoint = torch.load(args.model, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['state_dict'])
    # Inference only: put every layer in evaluation mode.
    model.eval()

    # Per-channel normalisation; these look like the usual ImageNet statistics
    # and presumably match what the checkpoint was trained with - confirm.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    img = transform(load_data(args.filepath))

    # No gradients are needed for a prediction, so skip building the autograd
    # graph; this keeps memory use down on large images.
    with torch.no_grad():
        output = model(img.unsqueeze(0))  # add the batch dimension
    print("Predicted Count : ", int(output.sum().item()))


# Script entry point. main() has no return statement, so sys.exit() receives
# None and the process exits with status 0 unless an exception propagates.
if __name__ == "__main__":
    sys.exit(main())