This guide walks you through the structure of a PyWorker. By the end, you will know all of the pieces of a PyWorker and be able to create your own.
Vast has pre-made templates with PyWorkers already built in. Check these templates first to see if a supported one works for your use case.
This repo contains all the components of a PyWorker. For pedagogical purposes, the workers/hello_world/ PyWorker is built for an LLM server with two API endpoints:
  1. /generate: generates an LLM response and returns it as a single JSON response
  2. /generate_stream: streams a response one token at a time
Both of these endpoints take the same API JSON payload:
JSON
{
    "prompt": String,
    "max_response_tokens": Number | null
}
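
For illustration, a concrete payload for either endpoint might look like the sketch below (the requests library and the model server address are assumptions here, matching the hello_world example later in this guide):
Python
import requests  # any HTTP client works; requests is used purely for illustration

# a concrete payload matching the schema above
payload = {"prompt": "Tell me about cats", "max_response_tokens": 128}

# calling the model API directly; the PyWorker's own routes additionally expect auth data,
# as described in the server.py section below
response = requests.post("http://0.0.0.0:5001/generate", json=payload)
print(response.json())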

Structure

All PyWorkers have four files:
Text
.
└── workers
    └── hello_world
        ├── __init__.py # blank file
        ├── data_types.py # contains data types representing model API endpoints
        ├── server.py # contains endpoint handlers
        └── test_load.py # script for load testing

All of the classes follow strict type hinting, and it is recommended that you type hint all of your functions as well. This allows your IDE, or VS Code with the Pyright plugin, to catch type errors in your implementation. You can also install Pyright with npm install pyright and run pyright from the root of the project to find any type errors.
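
For example, with full type hints in place, Pyright can flag a mismatch like the following before the worker ever runs (a contrived illustration, not code from the repo):
Python
def count_workload(prompt: str) -> float:
    # word count as a stand-in for a token count
    return len(prompt.split())

# pyright flags this line: the return type "float" is not assignable to "str"
workload: str = count_workload("hello world")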

__init__.py

The __init__.py file is left blank. This tells the Python interpreter to treat the hello_world directory as a package, which allows us to import modules from within the directory.

data_types.py

This file defines how the PyWorker interacts with the ML model, and it must adhere to the common framework laid out in lib/data_types.py. It implements the specific request structure and payload handling that will be used in server.py. Data handling classes must inherit from lib.data_types.ApiPayload. ApiPayload is an abstract class with several functions that must be defined. Below is an example implementation from the hello_world PyWorker that shows how to use the ApiPayload class.
Python
import dataclasses
import inspect
import random
from typing import Dict, Any

from transformers import AutoTokenizer # used to count tokens in a prompt
import nltk # used to download a word list for generating random prompts to benchmark the LLM model

from lib.data_types import ApiPayload, JsonDataException

nltk.download("words")
WORD_LIST = nltk.corpus.words.words()

# you can use any tokenizer that fits your LLM. `openai-gpt` is free to use and is a good fit for most LLMs
tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt")

@dataclasses.dataclass
class InputData(ApiPayload):
    prompt: str
    max_response_tokens: int

    @classmethod
    def for_test(cls) -> "ApiPayload":
        """defines how create a payload for load testing"""
        prompt = " ".join(random.choices(WORD_LIST, k=int(250)))
        return cls(prompt=prompt, max_response_tokens=300)

    def generate_payload_json(self) -> Dict[str, Any]:
        """defines how to convert an ApiPayload to JSON that will be sent to model API"""
        return dataclasses.asdict(self)

    def count_workload(self) -> float:
        """defines how to calculate workload for a payload"""
        return len(tokenizer.tokenize(self.prompt))

    @classmethod
    def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputData":
        """
        defines how to transform JSON data to AuthData and payload type,
        in this case `InputData` defined above represents the data sent to the model API.
        AuthData is data generated by the serverless system in order to authenticate payloads.
        In this case, the transformation is simple and 1:1. That is not always the case. See comfyui's PyWorker
        for more complicated examples
        """
        errors = {}
        for param in inspect.signature(cls).parameters:
            if param not in json_msg:
                errors[param] = "missing parameter"
        if errors:
            raise JsonDataException(errors)
        return cls(
            **{
                k: v
                for k, v in json_msg.items()
                if k in inspect.signature(cls).parameters
            }
        )
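
As a quick sanity check, here is a hypothetical snippet (not part of the worker itself) showing how these methods fit together; in practice the PyWorker framework calls them for you:
Python
# hypothetical usage sketch; the raw dict mirrors the JSON payload from the start of this guide
raw = {"prompt": "Tell me about cats", "max_response_tokens": 128}

data = InputData.from_json_msg(raw)   # validates the JSON and builds the dataclass
workload = data.count_workload()      # prompt token count, used for autoscaling metrics
body = data.generate_payload_json()   # the JSON dict forwarded to the model API
print(workload, body)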

Your specific use case could require additional classes or methods. Reference the TGI worker as another example.

server.py

For every ML model API endpoint you want to use, you must implement an EndpointHandler. This class handles incoming requests, processes them, sends them to the model API server, and finally returns an HTTP response with the model's results; it is what allows the PyWorker and the model server to communicate. EndpointHandler has several abstract functions that must be implemented. Here, we implement the /generate endpoint for the PyWorker by creating the GenerateHandler class, which inherits from EndpointHandler.
Python

"""
AuthData is a dataclass that represents Authentication data sent from the serverless system to the client requesting a route.
When a user requests a route, see Vast's Serverless documentation for how routing and AuthData
work.
When a user receives a route for this PyWorker, they'll call PyWorkers API with the following JSON:
{
    auth_data: AuthData,
    payload : InputData # defined above
}
"""
from aiohttp import web

from lib.data_types import EndpointHandler, JsonDataException
from lib.server import start_server
from .data_types import InputData

#### This class is the implementer for the '/generate' endpoint of model API
@dataclasses.dataclass
class GenerateHandler(EndpointHandler[InputData]):

    @property
    def endpoint(self) -> str:
        # the API endpoint
        return "/generate"

    @classmethod
    def payload_cls(cls) -> Type[InputData]:
        """this function should just return ApiPayload subclass used by this handler"""
        return InputData

    def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
        """
        defines how to convert `InputData` defined above, to
        JSON data to be sent to the model API. This function too is a simple dataclass -> JSON, but
        can be more complicated, See comfyui for an example
        """
        return dataclasses.asdict(payload)

    def make_benchmark_payload(self) -> InputData:
        """
        defines how to generate an InputData for benchmarking. This needs to be defined in only
        one EndpointHandler, the one passed to the backend as the benchmark handler. Here we use the .for_test()
        method on InputData. However, in some cases you might need to fine tune your InputData used for
        benchmarking to closely resemble the average request users call the endpoint with in order to get the best
        performance
        """
        return InputData.for_test()

    async def generate_client_response(
        self, client_request: web.Request, model_response: ClientResponse
    ) -> Union[web.Response, web.StreamResponse]:
        """
        defines how to convert a model API response to a response to PyWorker client
        """
        _ = client_request
        match model_response.status:
            case 200:
                log.debug("SUCCESS")
                data = await model_response.json()
                return web.json_response(data=data)
            case code:
                log.debug("SENDING RESPONSE: ERROR: unknown code")
                return web.Response(status=code)


We also implement GenerateStreamHandler for streaming responses. It is identical to GenerateHandler, except that it streams the model's response back to the client as a streaming web response:
Python
class GenerateStreamHandler(EndpointHandler[InputData]):
    @property
    def endpoint(self) -> str:
        return "/generate_stream"

    @classmethod
    def payload_cls(cls) -> Type[InputData]:
        return InputData

    def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
        return dataclasses.asdict(payload)

    def make_benchmark_payload(self) -> InputData:
        return InputData.for_test()

    async def generate_client_response(
        self, client_request: web.Request, model_response: ClientResponse
    ) -> Union[web.Response, web.StreamResponse]:
        match model_response.status:
            case 200:
                log.debug("Streaming response...")
                res = web.StreamResponse()
                res.content_type = "text/event-stream"
                await res.prepare(client_request)
                async for chunk in model_response.content:
                    await res.write(chunk)
                await res.write_eof()
                log.debug("Done streaming response")
                return res
            case code:
                log.debug("SENDING RESPONSE: ERROR: unknown code")
                return web.Response(status=code)


You can now instantiate a Backend and use it to handle requests.
Python
import os

from lib.backend import Backend, LogAction

# the url and port of the model API
MODEL_SERVER_URL = "http://0.0.0.0:5001"


# This is the log line that is emitted once the model server has started
MODEL_SERVER_START_LOG_MSG = "infer server has started"
MODEL_SERVER_ERROR_LOG_MSGS = [
    "Exception: corrupted model file"  # message in the logs indicating an unrecoverable error
]

backend = Backend(
    model_server_url=MODEL_SERVER_URL,
    # location of model log file
    model_log_file=os.environ["MODEL_LOG"],
    # for model backends that can only handle one request at a time, set this to False to
    # let the PyWorker handle queueing requests
    allow_parallel_requests=True,
    # give the backend an EndpointHandler instance that is used for benchmarking;
    # the number of benchmark runs and the number of words per random benchmark run are given
    benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
    # defines how to handle specific log messages. See docstring of LogAction for details
    log_actions=[
        (LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
        (LogAction.Info, '"message":"Download'),
        *[
            (LogAction.ModelError, error_msg)
            for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
        ],
    ],
)

# this is a simple ping handler for the PyWorker
async def handle_ping(_: web.Request):
    return web.Response(body="pong")

# this is a handler for forwarding a health check to the model API
async def handle_healthcheck(_: web.Request):
    healthcheck_res = await backend.session.get("/healthcheck")
    return web.Response(body=healthcheck_res.content, status=healthcheck_res.status)

routes = [
    web.post("/generate", backend.create_handler(GenerateHandler())),
    web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())),
    web.get("/ping", handle_ping),
    web.get("/healthcheck", handle_healthcheck),
]

if __name__ == "__main__":
    # start server, called from start_server.sh
    start_server(backend, routes)
Putting it all together, the full server.py module of the hello_world PyWorker looks like this:
Python
"""
PyWorker works as a man-in-the-middle between the client and model API. It's function is:
1. receive request from client, update metrics such as workload of a request, number of pending requests, etc.
2a. transform the data and forward the transformed data to model API
2b. send updated metrics to autoscaler
3. transform response from model API(if needed) and forward the response to client

PyWorker forward requests to many model API endpoint. each endpoint must have an EndpointHandler. You can also
write function to just forward requests that don't generate anything with the model to model API without an
EndpointHandler. This is useful for endpoints such as healthchecks. See below for example
"""

import os
import logging
import dataclasses
from typing import Dict, Any, Union, Type

from aiohttp import web, ClientResponse

from lib.backend import Backend, LogAction
from lib.data_types import EndpointHandler
from lib.server import start_server
from .data_types import InputData

# the url and port of model API
MODEL_SERVER_URL = "http://0.0.0.0:5001"


# This is the log line that is emitted once the server has started
MODEL_SERVER_START_LOG_MSG = "infer server has started"
MODEL_SERVER_ERROR_LOG_MSGS = [
    "Exception: corrupted model file"  # message in the logs indicating the unrecoverable error
]


logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s[%(levelname)-5s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__file__)


# This class is the implementer for the '/generate' endpoint of model API
@dataclasses.dataclass
class GenerateHandler(EndpointHandler[InputData]):

    @property
    def endpoint(self) -> str:
        # the API endpoint
        return "/generate"

    @classmethod
    def payload_cls(cls) -> Type[InputData]:
        return InputData

    def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
        """
        defines how to convert the `InputData` defined above into
        JSON data to be sent to the model API
        """
        return dataclasses.asdict(payload)

    def make_benchmark_payload(self) -> InputData:
        """
        defines how to generate an InputData for benchmarking. This needs to be defined in only
        one EndpointHandler, the one passed to the backend as the benchmark handler
        """
        return InputData.for_test()

    async def generate_client_response(
        self, client_request: web.Request, model_response: ClientResponse
    ) -> Union[web.Response, web.StreamResponse]:
        """
        defines how to convert a model API response to a response to PyWorker client
        """
        _ = client_request
        match model_response.status:
            case 200:
                log.debug("SUCCESS")
                data = await model_response.json()
                return web.json_response(data=data)
            case code:
                log.debug("SENDING RESPONSE: ERROR: unknown code")
                return web.Response(status=code)


# This is the same as GenerateHandler, except that it calls a streaming endpoint of the model API and
# streams the response, which itself is streaming, back to the client
class GenerateStreamHandler(EndpointHandler[InputData]):
    @property
    def endpoint(self) -> str:
        return "/generate_stream"

    @classmethod
    def payload_cls(cls) -> Type[InputData]:
        return InputData

    def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
        return dataclasses.asdict(payload)

    def make_benchmark_payload(self) -> InputData:
        return InputData.for_test()

    async def generate_client_response(
        self, client_request: web.Request, model_response: ClientResponse
    ) -> Union[web.Response, web.StreamResponse]:
        match model_response.status:
            case 200:
                log.debug("Streaming response...")
                res = web.StreamResponse()
                res.content_type = "text/event-stream"
                await res.prepare(client_request)
                async for chunk in model_response.content:
                    await res.write(chunk)
                await res.write_eof()
                log.debug("Done streaming response")
                return res
            case code:
                log.debug("SENDING RESPONSE: ERROR: unknown code")
                return web.Response(status=code)


# This is the backend instance of the PyWorker. Only one should be created; it uses the
# EndpointHandlers to process incoming requests
backend = Backend(
    model_server_url=MODEL_SERVER_URL,
    model_log_file=os.environ["MODEL_LOG"],
    allow_parallel_requests=True,
    # give the backend a handler instance that is used for benchmarking;
    # the number of benchmark runs and the number of words per random benchmark run are given
    benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
    # defines how to handle specific log messages. See docstring of LogAction for details
    log_actions=[
        (LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
        (LogAction.Info, '"message":"Download'),
        *[
            (LogAction.ModelError, error_msg)
            for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
        ],
    ],
)


# this is a simple ping handler for the PyWorker
async def handle_ping(_: web.Request):
    return web.Response(body="pong")


# this is a handler for forwarding a health check to the model API
async def handle_healthcheck(_: web.Request):
    healthcheck_res = await backend.session.get("/healthcheck")
    return web.Response(body=healthcheck_res.content, status=healthcheck_res.status)


routes = [
    web.post("/generate", backend.create_handler(GenerateHandler())),
    web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())),
    web.get("/ping", handle_ping),
    web.get("/healthcheck", handle_healthcheck),
]

if __name__ == "__main__":
    # start the PyWorker server
    start_server(backend, routes)

test_load.py

Once a Serverless Endpoint is set up with a Worker Group, the test_load module lets us test the running instances:
Python
from lib.test_harness import run
from .data_types import InputData

WORKER_ENDPOINT = "/generate"

if __name__ == "__main__":
    run(InputData.for_test(), WORKER_ENDPOINT)
To run the script, provide the following parameters:
  • -n is the total number of requests to be sent to the Endpoint
  • -rps is the rate (requests per second) at which the requests will be sent
  • -k is your Vast API key. You can define it in your environment or paste it into the command
  • -e is the name of the Serverless Endpoint
You can run the following command from the root of the PyWorker repo:
Text
python3 -m workers.hello_world.test_load -n 1000 -rps 0.5 -k "$API_KEY" -e "$ENDPOINT_NAME"
Be sure to define API_KEY and ENDPOINT_NAME in your environment before running, or replace them with their actual values.
For example, a successful test with n = 10 requests used 4 different GPU workers in the Worker Group to serve the 10 requests it was sent.

These are all the parts of a PyWorker! You will also find a client.py module in the worker folders of the repo. While it is not part of the PyWorker, Vast provides it as an example of how a user could interact with their model on the serverless system. The client.py file is not needed for the PyWorker to run on a GPU instance and is intended to run on your local machine. See the PyWorker Overview page for more details.