remove modal glue
This commit is contained in:
18
train.py
18
train.py
@@ -15,16 +15,6 @@ TEST_SIZE = 0.1
|
||||
|
||||
datasets.logging.set_verbosity(datasets.logging.INFO)
|
||||
|
||||
image = modal.Image.debian_slim().pip_install(
|
||||
"datasets==2.19.0",
|
||||
"albumentations==1.4.4",
|
||||
"numpy==1.26.4",
|
||||
"torch==2.2.2",
|
||||
)
|
||||
app = modal.App("multilayer-authenticity-identifier", image=image)
|
||||
volume = modal.Volume.from_name("model-store")
|
||||
model_store_path = "/vol/models"
|
||||
|
||||
|
||||
def collate(batch):
|
||||
pixel_values = []
|
||||
@@ -133,7 +123,6 @@ def load_data():
|
||||
return training_loader, validation_loader, len(training_ds)
|
||||
|
||||
|
||||
@app.function(gpu="T4", timeout=86400, volumes={model_store_path: volume})
|
||||
def train():
|
||||
training_loader, validation_loader, sample_size = load_data()
|
||||
|
||||
@@ -208,11 +197,8 @@ def train():
|
||||
|
||||
if avg_validation_loss < best_vloss:
|
||||
best_vloss = avg_validation_loss
|
||||
model_path = f"{model_store_path}/mai_{timestamp}_{epoch}"
|
||||
model_path = f"model/mai_{timestamp}_{epoch}"
|
||||
torch.save(model.state_dict(), model_path)
|
||||
volume.commit()
|
||||
|
||||
|
||||
@app.local_entrypoint()
|
||||
def main():
|
||||
train.remote()
|
||||
train()
|
||||
|
Reference in New Issue
Block a user