tweak hyperparams

This commit is contained in:
2024-05-16 22:51:38 +01:00
parent 9364d97bd7
commit 6085ef9352
2 changed files with 3 additions and 20 deletions

View File

@@ -2,4 +2,4 @@ FILTER_COUNT = 32
KERNEL_SIZE = 2 KERNEL_SIZE = 2
CROP_SIZE = 200 CROP_SIZE = 200
BATCH_SIZE = 4 BATCH_SIZE = 4
EPOCHS = 15 EPOCHS = 20

View File

@@ -1,27 +1,12 @@
import modal
import torch import torch
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from model import mai from model import mai
from augmentation import preprocess_validation from augmentation import preprocess_validation
MODEL_NAME = "mai_20240424_180855_4"
image = modal.Image.debian_slim().pip_install(
"datasets==2.19.0",
"albumentations==1.4.4",
"numpy==1.26.4",
"torch==2.2.2",
)
app = modal.App("multilayer-authenticity-identifier", image=image)
volume = modal.Volume.from_name("model-store")
model_store_path = "/vol/models"
@app.function(timeout=5000, gpu="T4", volumes={model_store_path: volume})
def load_model_and_run_inference(img): def load_model_and_run_inference(img):
print(f"REMOTE: {img.shape}") mai.load_state_dict(torch.load("mai"))
mai.load_state_dict(torch.load(f"{model_store_path}/{MODEL_NAME}"))
mai.eval() mai.eval()
img_batch = np.expand_dims(img, axis=0) img_batch = np.expand_dims(img, axis=0)
img_batch = torch.tensor(img_batch) img_batch = torch.tensor(img_batch)
@@ -30,9 +15,7 @@ def load_model_and_run_inference(img):
print(prediction) print(prediction)
@app.local_entrypoint()
def main(): def main():
img = Image.open("test_images/dog.jpg") img = Image.open("test_images/dog.jpg")
img = preprocess_validation(image=np.array(img))["image"] img = preprocess_validation(image=np.array(img))["image"]
print(img.shape) load_model_and_run_inference(img)
load_model_and_run_inference.remote(img)