import os
import sys

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from data_pipe import get_data
from resnet18 import ResNet18
from vgg import VGG_13
10
11
class Infer(object):
    """Classify every image in a directory with a trained ResNet18 checkpoint.

    The predicted label of each image is printed on its own line.
    """

    def __init__(self, model_path="./models/model_65.pth"):
        """Build the network and load its weights.

        Args:
            model_path: Path of the ``state_dict`` checkpoint to load. Defaults
                to the checkpoint this script has always used.
        """
        self.model = ResNet18()
        # map_location="cpu" lets a checkpoint that was saved from a GPU be
        # loaded on a CPU-only machine; the model is never moved to another
        # device here, so inference runs on the CPU anyway.
        state_dict = torch.load(model_path, map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.model.eval()
        # Label -> output index. Indices follow the lexicographic order of the
        # label strings (' 0', ' 1', ' 10', ...), presumably the order in which
        # the training pipeline enumerated its classes -- TODO confirm against
        # data_pipe. Keys carry a leading space that is stripped on output.
        self.cls = {' 0': 0, ' 1': 1, ' 10': 2, ' 11': 3, ' 12': 4, ' 13': 5, ' 14': 6, ' 15': 7, ' 16': 8, ' 17': 9, ' 18': 10, ' 19': 11, ' 2': 12, ' 20': 13, ' 21': 14, ' 22': 15, ' 23': 16, ' 24': 17, ' 25': 18, ' 26': 19, ' 27': 20, ' 28': 21, ' 29': 22, ' 3': 23, ' 30': 24, ' 31': 25, ' 32': 26, ' 33': 27, ' 34': 28, ' 35': 29, ' 36': 30, ' 37': 31, ' 38': 32, ' 39': 33, ' 4': 34, ' 5': 35, ' 6': 36, ' 7': 37, ' 8': 38, ' 9': 39}
        # Output index -> label, used to decode predictions.
        self.new_cls = {index: label for label, index in self.cls.items()}

    def _infer(self, img_tensor):
        """Run a forward pass without gradient tracking and return the model output."""
        with torch.no_grad():
            return self.model(img_tensor)

    def _decode(self, result):
        """Return the stripped label of the highest-scoring class in row 0 of ``result``."""
        # argmax returns the first index on ties, matching torch.max(dim=1).
        index = int(result.argmax(dim=1)[0])
        return self.new_cls[index].strip()

    def predict(self, path):
        """Classify every readable image directly inside directory ``path``.

        Entries are processed in sorted filename order so output is
        deterministic. Each predicted label is printed on its own line.
        Entries OpenCV cannot read as an image (e.g. non-image files or
        subdirectories) are skipped with a note on stderr.

        Args:
            path: Directory containing the images to classify.

        Returns:
            List of predicted labels, in the same order they were printed.
        """
        transform = transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
        labels = []
        for name in sorted(os.listdir(path)):
            img_path = os.path.join(path, name)
            bgr = cv2.imread(img_path)
            if bgr is None:
                # cv2.imread signals an unreadable/non-image file by returning
                # None; cvtColor would raise on it, so skip this entry.
                print("skipping unreadable file: %s" % img_path, file=sys.stderr)
                continue
            # OpenCV decodes as BGR; the model expects RGB PIL input.
            img = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
            # Add the leading batch dimension: (3, 224, 224) -> (1, 3, 224, 224).
            img_tensor = transform(img).unsqueeze(0)
            label = self._decode(self._infer(img_tensor))
            print(label)
            labels.append(label)
        return labels
40
41
if __name__ == "__main__":
    # Script entry point: classify every image in the local test directory.
    test_dir = "./test_images"
    predictor = Infer()
    predictor.predict(test_dir)