From 6085ef93523ddd1479660996888b8a91341c3cbe Mon Sep 17 00:00:00 2001
From: Kenneth
Date: Thu, 16 May 2024 22:51:38 +0100
Subject: [PATCH] tweak hyperparams

---
 hyperparams.py |  2 +-
 inference.py   | 21 ++-------------------
 2 files changed, 3 insertions(+), 20 deletions(-)

diff --git a/hyperparams.py b/hyperparams.py
index 2333fd3..e9b67f2 100644
--- a/hyperparams.py
+++ b/hyperparams.py
@@ -2,4 +2,4 @@ FILTER_COUNT = 32
 KERNEL_SIZE = 2
 CROP_SIZE = 200
 BATCH_SIZE = 4
-EPOCHS = 15
+EPOCHS = 20
diff --git a/inference.py b/inference.py
index a062889..f0a86b8 100644
--- a/inference.py
+++ b/inference.py
@@ -1,27 +1,12 @@
-import modal
 import torch
 import numpy as np
 from PIL import Image
 
 from model import mai
 from augmentation import preprocess_validation
 
-MODEL_NAME = "mai_20240424_180855_4"
-image = modal.Image.debian_slim().pip_install(
-    "datasets==2.19.0",
-    "albumentations==1.4.4",
-    "numpy==1.26.4",
-    "torch==2.2.2",
-)
-app = modal.App("multilayer-authenticity-identifier", image=image)
-volume = modal.Volume.from_name("model-store")
-model_store_path = "/vol/models"
-
-
-@app.function(timeout=5000, gpu="T4", volumes={model_store_path: volume})
 def load_model_and_run_inference(img):
-    print(f"REMOTE: {img.shape}")
-    mai.load_state_dict(torch.load(f"{model_store_path}/{MODEL_NAME}"))
+    mai.load_state_dict(torch.load("mai"))
     mai.eval()
     img_batch = np.expand_dims(img, axis=0)
     img_batch = torch.tensor(img_batch)
@@ -30,9 +15,7 @@ def load_model_and_run_inference(img):
     print(prediction)
 
 
-@app.local_entrypoint()
 def main():
     img = Image.open("test_images/dog.jpg")
     img = preprocess_validation(image=np.array(img))["image"]
-    print(img.shape)
-    load_model_and_run_inference.remote(img)
+    load_model_and_run_inference(img)