Documentation Index
Fetch the complete documentation index at: https://docs.vast.ai/llms.txt
Use this file to discover all available pages before exploring further.
This example demonstrates using @context to train a model at worker startup and then serve inference requests through a @remote function.
Deployment
# deploy.py
from vastai import Deployment
from vastai.data.query import gpu_name, RTX_4090, RTX_5090
# Deployment handle named "train-mnist". The @app.context() class, the
# @app.remote() function, the worker image and the autoscaling settings in
# this file are all registered on this object.
app = Deployment(name="train-mnist")
@app.context()
class MNISTModel:
    """Worker context that trains a small MNIST CNN when it is entered.

    ``__aenter__`` downloads the MNIST training split to ``/tmp/mnist``,
    trains the network for ``EPOCHS`` epochs (Adam, lr=1e-3, cross-entropy,
    batch size 64), switches it to eval mode and stores it on ``self``.
    Remote functions obtain this instance via ``app.get_context(MNISTModel)``.

    Attributes set by ``__aenter__``:
        model: the trained CNN, in eval mode, placed on ``device``.
        device: ``torch.device("cuda")`` if CUDA is available, else CPU.
    """

    # Number of passes over the training set. Used for both the loop bound
    # and the progress message so the two cannot drift apart.
    EPOCHS = 3

    async def __aenter__(self):
        """Train the classifier and return ``self`` with ``model``/``device`` set.

        NOTE: there is no ``await`` in this method, so download and training
        run synchronously and hold the event loop until they finish.
        """
        import torch
        import torch.nn as nn
        import torch.optim as optim
        from torchvision import datasets, transforms

        class CNN(nn.Module):
            """Two conv+ReLU+max-pool stages followed by two linear layers.

            With 3x3 convs (padding=1) preserving size and two 2x2 pools,
            a 1x28x28 input becomes 64x7x7 -- hence ``fc1``'s input width.
            """

            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
                self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
                self.pool = nn.MaxPool2d(2)
                self.fc1 = nn.Linear(64 * 7 * 7, 128)
                self.fc2 = nn.Linear(128, 10)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.pool(self.relu(self.conv1(x)))
                x = self.pool(self.relu(self.conv2(x)))
                # Flatten per sample, keeping the batch dimension. Unlike
                # view(-1, 64 * 7 * 7), this cannot silently fold features
                # into extra "samples" when the spatial size is wrong; fc1
                # rejects the mismatched feature count instead.
                x = torch.flatten(x, 1)
                x = self.relu(self.fc1(x))
                return self.fc2(x)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = CNN().to(device)
        optimizer = optim.Adam(model.parameters(), lr=1e-3)
        loss_fn = nn.CrossEntropyLoss()

        # Mean/std normalization; infer() applies the same constants to
        # raw pixel values so inference sees the training distribution.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        train_data = datasets.MNIST("/tmp/mnist", train=True, download=True, transform=transform)
        loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)

        print("Training MNIST classifier...")
        epochs = self.EPOCHS
        model.train()
        for epoch in range(epochs):
            total_loss = 0.0
            for images, labels in loader:
                images, labels = images.to(device), labels.to(device)
                optimizer.zero_grad()
                loss = loss_fn(model(images), labels)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            # Mean per-batch loss over this epoch.
            print(f" Epoch {epoch + 1}/{epochs} loss={total_loss / len(loader):.4f}")

        # Switch to eval mode before serving inference requests.
        model.eval()
        self.model = model
        self.device = device
        print("Training complete. Model ready for inference.")
        return self

    async def __aexit__(self, *exc):
        """Nothing to release; returning None lets any exception propagate."""
        pass
@app.remote(benchmark_dataset=[{"pixel_values": [[0.0] * 28 for _ in range(28)]}])
async def infer(pixel_values: list[list[float]]) -> dict:
    """Classify a 28x28 grayscale MNIST image.

    Args:
        pixel_values: 28x28 nested list of floats (0.0=black, 1.0=white),
            raw pixel intensities before normalization.

    Returns:
        dict with "digit" (predicted class, an int) and "probability"
        (softmax confidence of that class, a float).

    Raises:
        ValueError: if ``pixel_values`` does not form a 28x28 grid
            (ragged rows are rejected by ``torch.tensor`` itself).
    """
    import torch

    ctx = app.get_context(MNISTModel)
    tensor = torch.tensor(pixel_values, dtype=torch.float32)
    # The CNN's fc1 layer is sized for 28x28 inputs; fail early with a clear
    # message instead of an obscure error from inside the model.
    if tensor.shape != (28, 28):
        raise ValueError(
            f"pixel_values must be a 28x28 grid, got shape {tuple(tensor.shape)}"
        )
    # Same mean/std normalization that was applied to the training data.
    tensor = (tensor - 0.1307) / 0.3081
    # (28, 28) -> (batch=1, channels=1, 28, 28)
    tensor = tensor.unsqueeze(0).unsqueeze(0).to(ctx.device)
    with torch.no_grad():
        logits = ctx.model(tensor)
    probs = torch.softmax(logits, dim=1)
    prob, digit = probs.max(dim=1)
    return {"digit": digit.item(), "probability": prob.item()}
# Worker container image for this deployment. The second argument (16) is
# passed straight to app.image().
# NOTE(review): presumably the disk allocation in GB — confirm against the vastai SDK docs.
image = app.image("vastai/pytorch:@vastai-automatic-tag", 16)
# Point the image at the virtualenv that already exists inside it, at /venv/main.
image.venv("/venv/main")
# Require the worker's GPU name to be RTX 4090 or RTX 5090.
image.require(gpu_name.in_([RTX_4090, RTX_5090]))
# NOTE(review): the meaning/units of min_load are defined by the vastai
# autoscaler — confirm before tuning.
app.configure_autoscaling(min_load=100)
# Module-level call: client.py does `from deploy import app, infer`, so this
# also runs whenever deploy.py is imported.
# NOTE(review): presumably deploys and/or waits for workers to be ready — confirm.
app.ensure_ready()
Client
# client.py
import asyncio
import random
from deploy import app, infer
async def main():
    """Send one random MNIST test image to the remote ``infer`` and print the result."""
    from torchvision import datasets, transforms

    to_tensor = transforms.ToTensor()
    mnist_test = datasets.MNIST("/tmp/mnist", train=False, download=True, transform=to_tensor)

    # Pick a uniformly random sample; each item is (image tensor, label).
    pick = random.randrange(len(mnist_test))
    sample, true_label = mnist_test[pick]

    # Drop the leading channel axis and convert to the nested float list
    # that infer() accepts.
    grid = sample.squeeze(0).tolist()
    result = await infer(grid)

    print(f"True label: {true_label}")
    print(f"Predicted: {result['digit']}")
    print(f"Confidence: {result['probability']:.4f}")


if __name__ == "__main__":
    asyncio.run(main())
What This Demonstrates
- Using `@context` to train a model at worker startup
- The context’s `__aenter__` runs GPU-intensive work (training) before the worker enters the “ready” state
- Accessing the trained model in a `@remote` function via `app.get_context()`
- A benchmark dataset whose single input is an all-zero (blank) 28x28 image, matching the shape of real requests
- Using `image.venv()` to point to an existing virtual environment in the Docker image