feat: add code for finetuning moondream

This commit is contained in:
2024-12-08 17:33:37 +00:00
parent 4bf2c89bf8
commit 6b53cb0411
11 changed files with 261 additions and 0 deletions

21
resnet/inference.py Normal file
View File

@@ -0,0 +1,21 @@
import torch
import numpy as np
from PIL import Image
from model import mai
from augmentation import preprocess_validation
def load_model_and_run_inference(img):
mai.load_state_dict(torch.load("mai"))
mai.eval()
img_batch = np.expand_dims(img, axis=0)
img_batch = torch.tensor(img_batch)
prediction = mai(img_batch)
prediction = torch.sigmoid(prediction)
print(prediction)
def main():
img = Image.open("test_images/dog.jpg")
img = preprocess_validation(image=np.array(img))["image"]
load_model_and_run_inference(img)