Example: Training a Model at Worker Startup with @context
This example demonstrates using @context to train a model at worker startup and then serve inference requests through a @remote function.

Deployment

# deploy.py
from vastai import Deployment
from vastai.data.query import gpu_name, RTX_4090, RTX_5090

# Deployment handle for this app; `name` identifies it and `webserver_url`
# points at the Vast.ai control-plane server — presumably where workers
# register and requests are routed (confirm against vastai SDK docs).
app = Deployment(name="train-mnist", webserver_url="https://alpha-server.vast.ai")


@app.context()
class MNISTModel:
    """Worker-startup context that trains a small CNN on MNIST.

    Training runs once inside ``__aenter__`` so the worker only becomes
    ready after a usable model exists.  The trained network and its
    device are published as ``self.model`` and ``self.device`` for the
    inference function to read via ``app.get_context()``.
    """

    async def __aenter__(self):
        import torch
        import torch.nn as nn
        import torch.optim as optim
        from torchvision import datasets, transforms

        class Net(nn.Module):
            """Two conv+pool stages followed by a two-layer classifier head."""

            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
                self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
                self.pool = nn.MaxPool2d(2)
                # Two 2x pools shrink 28x28 to 7x7, hence 64 * 7 * 7 inputs.
                self.fc1 = nn.Linear(64 * 7 * 7, 128)
                self.fc2 = nn.Linear(128, 10)
                self.relu = nn.ReLU()

            def forward(self, x):
                h = self.pool(self.relu(self.conv1(x)))
                h = self.pool(self.relu(self.conv2(h)))
                h = h.view(-1, 64 * 7 * 7)
                return self.fc2(self.relu(self.fc1(h)))

        run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        net = Net().to(run_device)
        opt = optim.Adam(net.parameters(), lr=1e-3)
        criterion = nn.CrossEntropyLoss()

        # Same preprocessing that `infer` re-applies by hand: scale to
        # [0, 1] then normalize with the standard MNIST mean/std.
        preprocess = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        train_set = datasets.MNIST("/tmp/mnist", train=True, download=True, transform=preprocess)
        train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)

        print("Training MNIST classifier...")
        net.train()
        for epoch in range(3):
            running = 0.0
            for batch_images, batch_labels in train_loader:
                batch_images = batch_images.to(run_device)
                batch_labels = batch_labels.to(run_device)
                opt.zero_grad()
                batch_loss = criterion(net(batch_images), batch_labels)
                batch_loss.backward()
                opt.step()
                running += batch_loss.item()
            print(f"  Epoch {epoch + 1}/3  loss={running / len(train_loader):.4f}")

        net.eval()
        self.model = net
        self.device = run_device
        print("Training complete. Model ready for inference.")
        return self

    async def __aexit__(self, *exc):
        # Nothing to release: the model is reclaimed with the context object.
        pass


@app.remote(benchmark_dataset=[{"pixel_values": [[0.0] * 28] * 28}])
async def infer(pixel_values: list[list[float]]) -> dict:
    """Predict the digit shown in a 28x28 grayscale MNIST image.

    Args:
        pixel_values: 28x28 nested list of floats (0.0=black, 1.0=white),
                      raw pixel intensities before normalization.

    Returns:
        dict with "digit" (predicted class) and "probability" (confidence).
    """
    import torch

    trained = app.get_context(MNISTModel)

    # Re-apply the training-time normalization, then add the batch and
    # channel axes: (28, 28) -> (1, 1, 28, 28).
    x = torch.tensor(pixel_values, dtype=torch.float32)
    x = (x - 0.1307) / 0.3081
    x = x.unsqueeze(0).unsqueeze(0).to(trained.device)

    with torch.no_grad():
        scores = trained.model(x)
        probabilities = torch.softmax(scores, dim=1)
        confidence, predicted = torch.max(probabilities, dim=1)

    return {"digit": predicted.item(), "probability": confidence.item()}


# Worker image based on the auto-tagged Vast.ai PyTorch container; the
# second argument is presumably a size/resource quota — confirm in vastai docs.
image = app.image("vastai/pytorch:@vastai-automatic-tag", 16)
# Reuse the virtualenv already baked into the Docker image.
image.venv("/venv/main")
# Restrict scheduling to RTX 4090 / RTX 5090 machines.
image.require(gpu_name.in_([RTX_4090, RTX_5090]))
# NOTE(review): min_load units (requests? percent?) are not visible here — verify.
app.configure_autoscaling(min_load=100)
app.ensure_ready()

Client

# client.py
import asyncio
import random
from deploy import app, infer


async def main():
    """Send one random MNIST test image through the deployment and print the result."""
    from torchvision import datasets, transforms

    # Grab a random held-out sample so the true label is known.
    test_data = datasets.MNIST("/tmp/mnist", train=False, download=True, transform=transforms.ToTensor())
    sample_index = random.randint(0, len(test_data) - 1)
    image_tensor, true_label = test_data[sample_index]

    # Drop the channel axis: (1, 28, 28) tensor -> 28x28 nested lists of floats.
    pixel_values = image_tensor.squeeze(0).tolist()

    result = await infer(pixel_values)
    print(f"True label:  {true_label}")
    print(f"Predicted:   {result['digit']}")
    print(f"Confidence:  {result['probability']:.4f}")

# Script entry point: run the async client once.
if __name__ == "__main__":
    asyncio.run(main())

What This Demonstrates

  • Using @context to train a model at worker startup
  • The context’s __aenter__ runs GPU-intensive work (training) before the worker enters “ready” state
  • Accessing the trained model in a @remote function via app.get_context()
  • Benchmark dataset with a representative zero-image input
  • Using image.venv() to point to an existing venv in the Docker image