tweak hyperparams
This commit is contained in:
@@ -2,4 +2,4 @@ FILTER_COUNT = 32
|
||||
KERNEL_SIZE = 2
|
||||
CROP_SIZE = 200
|
||||
BATCH_SIZE = 4
|
||||
EPOCHS = 15
|
||||
EPOCHS = 20
|
||||
|
21
inference.py
21
inference.py
@@ -1,27 +1,12 @@
|
||||
import modal
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from model import mai
|
||||
from augmentation import preprocess_validation
|
||||
|
||||
MODEL_NAME = "mai_20240424_180855_4"
|
||||
|
||||
image = modal.Image.debian_slim().pip_install(
|
||||
"datasets==2.19.0",
|
||||
"albumentations==1.4.4",
|
||||
"numpy==1.26.4",
|
||||
"torch==2.2.2",
|
||||
)
|
||||
app = modal.App("multilayer-authenticity-identifier", image=image)
|
||||
volume = modal.Volume.from_name("model-store")
|
||||
model_store_path = "/vol/models"
|
||||
|
||||
|
||||
@app.function(timeout=5000, gpu="T4", volumes={model_store_path: volume})
|
||||
def load_model_and_run_inference(img):
|
||||
print(f"REMOTE: {img.shape}")
|
||||
mai.load_state_dict(torch.load(f"{model_store_path}/{MODEL_NAME}"))
|
||||
mai.load_state_dict(torch.load("mai"))
|
||||
mai.eval()
|
||||
img_batch = np.expand_dims(img, axis=0)
|
||||
img_batch = torch.tensor(img_batch)
|
||||
@@ -30,9 +15,7 @@ def load_model_and_run_inference(img):
|
||||
print(prediction)
|
||||
|
||||
|
||||
@app.local_entrypoint()
|
||||
def main():
|
||||
img = Image.open("test_images/dog.jpg")
|
||||
img = preprocess_validation(image=np.array(img))["image"]
|
||||
print(img.shape)
|
||||
load_model_and_run_inference.remote(img)
|
||||
load_model_and_run_inference(img)
|
||||
|
Reference in New Issue
Block a user