39 lines
1.1 KiB
Python
39 lines
1.1 KiB
Python
|
import modal
|
||
|
import torch
|
||
|
import numpy as np
|
||
|
from PIL import Image
|
||
|
from model import mai
|
||
|
from augmentation import preprocess_validation
|
||
|
|
||
|
# Filename of the trained checkpoint stored in the model volume.
MODEL_NAME = "mai_20240424_180855_4"

# Container image for the remote functions, with pinned package versions.
image = modal.Image.debian_slim().pip_install(
    "datasets==2.19.0",
    "albumentations==1.4.4",
    "numpy==1.26.4",
    "torch==2.2.2",
)

# Modal application; every function registered on it runs in `image`.
app = modal.App("multilayer-authenticity-identifier", image=image)

# Mount point for the named "model-store" volume inside the container.
model_store_path = "/vol/models"
volume = modal.Volume.from_name("model-store")
|
||
|
|
||
|
|
||
|
@app.function(timeout=5000, gpu="T4", volumes={model_store_path: volume})
def load_model_and_run_inference(img):
    """Load the MAI checkpoint from the mounted volume and score one image.

    Args:
        img: One preprocessed image, as produced locally by
            ``preprocess_validation``. A leading batch axis of size 1 is added
            here; presumably ``img`` is laid out the way the model expects
            for a single sample (e.g. channel-first) — TODO confirm.

    Returns:
        The sigmoid of the model output for the one-image batch, as a CPU
        tensor. The value is also printed, as before.
    """
    print(f"REMOTE: {img.shape}")

    # The function reserves a T4, so actually run on it when CUDA is visible.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # map_location makes loading independent of the device the checkpoint
    # was saved from.
    state_dict = torch.load(
        f"{model_store_path}/{MODEL_NAME}", map_location=device
    )
    mai.load_state_dict(state_dict)
    mai.to(device)
    mai.eval()

    # Add a batch dimension of size 1; as_tensor avoids an extra host copy.
    img_batch = torch.as_tensor(np.expand_dims(img, axis=0), device=device)

    # Pure inference: skip autograd graph construction to save memory/time.
    with torch.no_grad():
        prediction = torch.sigmoid(mai(img_batch))

    # Move to CPU so the result can be printed and returned to the caller
    # without depending on a GPU on the receiving side.
    prediction = prediction.cpu()
    print(prediction)
    return prediction
|
||
|
|
||
|
|
||
|
@app.local_entrypoint()
def main(image_path: str = "test_images/dog.jpg"):
    """Preprocess a local image and send it to the remote inference function.

    Args:
        image_path: Path of the image to score. Because this is a Modal
            local entrypoint, it can be overridden on the command line;
            defaults to the original hard-coded test image.
    """
    # Image.open is lazy and keeps the file handle open; the context manager
    # closes it once the pixels have been copied into a NumPy array.
    with Image.open(image_path) as pil_img:
        # Normalize to 3-channel RGB so grayscale/CMYK/RGBA files yield the
        # same layout as ordinary RGB JPEGs (assumes the model was trained on
        # RGB input — confirm against the training pipeline).
        pixels = np.array(pil_img.convert("RGB"))

    img = preprocess_validation(image=pixels)["image"]
    print(img.shape)
    load_model_and_run_inference.remote(img)
|