Files
mai/inference.py

22 lines
561 B
Python
Raw Normal View History

2024-04-26 22:56:23 +01:00
import torch
import numpy as np
from PIL import Image
from model import mai
from augmentation import preprocess_validation
def load_model_and_run_inference(img, model=None, weights_path="mai"):
    """Load saved weights into the model, run it on one sample, print and return the result.

    Args:
        img: A single preprocessed sample with no batch dimension; a leading
            batch axis of size 1 is added here. NOTE(review): the expected
            layout/dtype is set by the preprocessing and the model, neither
            of which is visible in this file — confirm.
        model: The ``torch.nn.Module`` to run. Defaults to the project's
            ``mai`` model when omitted.
        weights_path: Path of the saved ``state_dict`` loaded before inference.

    Returns:
        ``torch.sigmoid`` of the model output for the batch of one. The same
        tensor is also printed, as before.
    """
    if model is None:
        model = mai

    # map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
    # machine; load_state_dict then copies the values into the model's own
    # parameters, wherever they live.
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    state_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()

    # Prepend the batch dimension expected by the model.
    img_batch = torch.tensor(np.expand_dims(img, axis=0))

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        logits = model(img_batch)
    prediction = torch.sigmoid(logits)

    print(prediction)
    return prediction
def main(image_path="test_images/dog.jpg"):
    """Preprocess one image file and run model inference on it.

    Args:
        image_path: Path of the image to run through the model; defaults to
            the repository's test image.
    """
    # Image.open is lazy and keeps the file handle open; the context manager
    # closes it once the pixels have been copied into the NumPy array.
    with Image.open(image_path) as pil_img:
        pixels = np.array(pil_img)

    img = preprocess_validation(image=pixels)["image"]
    load_model_and_run_inference(img)