22 lines
561 B
Python
22 lines
561 B
Python
import torch
|
|
import numpy as np
|
|
from PIL import Image
|
|
from model import mai
|
|
from augmentation import preprocess_validation
|
|
|
|
|
|
def load_model_and_run_inference(img, weights_path="mai"):
    """Load saved weights into ``mai`` and print its sigmoid output for one image.

    Args:
        img: A single preprocessed image without a batch dimension, as returned
            by ``preprocess_validation(...)["image"]``. Presumably already in the
            layout/dtype the model expects (e.g. float CHW) — TODO confirm
            against ``augmentation.preprocess_validation``.
        weights_path: Path of the saved ``state_dict`` to load into ``mai``.
            Defaults to ``"mai"``, the path this function always used.

    Returns:
        The sigmoid-activated model output for the batch containing ``img``.
        The output is also printed, as before.
    """
    # NOTE(review): torch.load unpickles the file. Only load weights from a
    # trusted source.
    # map_location="cpu" lets a checkpoint that was saved from GPU tensors load
    # on a CPU-only machine.
    state_dict = torch.load(weights_path, map_location="cpu")
    mai.load_state_dict(state_dict)
    mai.eval()

    # Prepend a batch axis of size 1; the model is called on batches.
    img_batch = torch.tensor(np.expand_dims(img, axis=0))

    # Inference only: disable autograd so no computation graph is recorded.
    with torch.no_grad():
        prediction = torch.sigmoid(mai(img_batch))

    print(prediction)
    return prediction
|
|
|
|
|
|
def main(image_path="test_images/dog.jpg"):
    """Preprocess one image and print the model's prediction for it.

    Args:
        image_path: Image file to run through the model. Defaults to
            ``"test_images/dog.jpg"``, the path this script always used.
    """
    # The context manager closes the file handle that Image.open keeps open
    # for lazy loading. convert("RGB") guarantees three channels, so
    # grayscale/RGBA/palette/CMYK files do not reach preprocessing with an
    # unexpected channel count.
    with Image.open(image_path) as pil_img:
        pixels = np.array(pil_img.convert("RGB"))

    img = preprocess_validation(image=pixels)["image"]
    load_model_and_run_inference(img)
|