# PyWorker Extension Guide
Creating your own PyWorker can be complex and challenging, with many potential pitfalls. If you need assistance with adding new PyWorkers, please don't hesitate to contact us.

This guide walks you through adding new backends. It is taken from the hello_world worker's README in the Vast PyWorker repository: https://github.com/vast-ai/pyworker/

There is a hello_world PyWorker implementation under `workers/hello_world`. This PyWorker is created for an LLM model server that runs on port 5001 and has two API endpoints:

* `/generate`: generates a full response to the prompt and sends a JSON response
* `/generate_stream`: streams a response one token at a time

Both of these endpoints take the same API JSON payload:

```
{
  "prompt": string,
  "max_response_tokens": number | null
}
```

We want the PyWorker to also expose two endpoints, one for each of the model API endpoints above.

## Structure

All PyWorkers must have two files: `data_types.py` and `server.py`. The hello_world worker is laid out as follows:

```
└── workers
    └── hello_world
        ├── __init__.py
        ├── data_types.py   # contains data types representing model API endpoints
        ├── server.py       # contains endpoint handlers
        ├── client.py       # a script to call an endpoint through the autoscaler
        └── test_load.py    # script for load testing
```

All of the classes follow strict type hinting. It is recommended that you type hint all of your functions; this allows your IDE, or VS Code with the Pyright plugin, to find any type errors in your implementation. You can also install Pyright with `npm install pyright` and run `pyright` in the root of the project to find any type errors.

## data_types.py

Dataclasses representing the model API are defined here. They must inherit from `lib.data_types.ApiPayload`. `ApiPayload` is an abstract class, and you need to define several functions for it.

```python
import dataclasses
import random
import inspect
from typing import Dict, Any

from transformers import AutoTokenizer  # used to count tokens in a prompt
import nltk  # used to download a list of all words to generate a random prompt and benchmark the LLM model

from lib.data_types import ApiPayload, JsonDataException

nltk.download("words")
word_list = nltk.corpus.words.words()

# you can use any tokenizer that fits your LLM; `openai-gpt` is free to use and is a good fit for most LLMs
tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt")


@dataclasses.dataclass
class InputData(ApiPayload):
    prompt: str
    max_response_tokens: int

    @classmethod
    def for_test(cls) -> "ApiPayload":
        """defines how to create a payload for load testing"""
        prompt = " ".join(random.choices(word_list, k=int(250)))
        return cls(prompt=prompt, max_response_tokens=300)

    def generate_payload_json(self) -> Dict[str, Any]:
        """defines how to convert an ApiPayload to JSON that will be sent to the model API"""
        return dataclasses.asdict(self)

    def count_workload(self) -> float:
        """defines how to calculate the workload for a payload"""
        return len(tokenizer.tokenize(self.prompt))

    @classmethod
    def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputData":
        """
        defines how to transform JSON data into AuthData and the payload type, in this case
        `InputData` defined above. `InputData` represents the data sent to the model API.
        AuthData is data generated by the autoscaler in order to authenticate payloads.
        In this case, the transformation is simple and 1:1; that is not always the case.
        See ComfyUI's PyWorker for more complicated examples.
        """
        errors = {}
        for param in inspect.signature(cls).parameters:
            if param not in json_msg:
                errors[param] = "missing parameter"
        if errors:
            raise JsonDataException(errors)
        return cls(
            **{
                k: v
                for k, v in json_msg.items()
                if k in inspect.signature(cls).parameters
            }
        )
```
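A quick way to check that your payload class behaves as expected is to exercise it directly. The snippet below is only an illustrative sketch; it assumes the repo root is on your `PYTHONPATH` and that the file above lives at `workers/hello_world/data_types.py`:

```python
from workers.hello_world.data_types import InputData

# build a random benchmark payload, the same way the benchmark handler will
payload = InputData.for_test()

print(payload.count_workload())         # workload = number of tokens in the prompt
print(payload.generate_payload_json())  # the JSON dict that will be forwarded to the model API

# from_json_msg() rebuilds the payload from incoming JSON and rejects missing fields
same_payload = InputData.from_json_msg(payload.generate_payload_json())
assert same_payload == payload
```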
## server.py

For every model API endpoint you want to use, you must implement an `EndpointHandler`. This class handles incoming requests, processes them, sends them to the model API server, and finally returns an HTTP response. `EndpointHandler` has several abstract functions that must be implemented. Here, we implement two: one for `/generate`, and one for `/generate_stream`.

```python
"""
AuthData is a dataclass that represents authentication data sent from the autoscaler to a client
when a user requests a route from the autoscaler. See Vast's Serverless documentation for how
routing and AuthData work. When a user receives a route for this PyWorker, they'll call the
PyWorker's API with the following JSON:
{
    "auth_data": AuthData,
    "payload": InputData  # defined above
}
"""

import os
import logging
import dataclasses
from typing import Any, Dict, Type, Union

from aiohttp import web, ClientResponse

from lib.data_types import EndpointHandler, JsonDataException
from lib.server import start_server

from .data_types import InputData

log = logging.getLogger(__name__)


# this class is the implementer for the '/generate' endpoint of the model API
@dataclasses.dataclass
class GenerateHandler(EndpointHandler[InputData]):
    @property
    def endpoint(self) -> str:
        # the API endpoint
        return "/generate"

    @classmethod
    def payload_cls(cls) -> Type[InputData]:
        """this function should just return the ApiPayload subclass used by this handler"""
        return InputData

    def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
        """
        defines how to convert `InputData` defined above to JSON data to be sent to the
        model API. This function too is a simple dataclass -> JSON conversion, but it can be
        more complicated; see ComfyUI for an example
        """
        return dataclasses.asdict(payload)

    def make_benchmark_payload(self) -> InputData:
        """
        defines how to generate an InputData for benchmarking. This needs to be defined in only
        one EndpointHandler, the one passed to the backend as the benchmark handler. Here we use
        the for_test() method on InputData. However, in some cases you might need to fine-tune
        the InputData used for benchmarking so that it closely resembles the average request
        users call the endpoint with, in order to get the best autoscaling performance
        """
        return InputData.for_test()

    async def generate_client_response(
        self, client_request: web.Request, model_response: ClientResponse
    ) -> Union[web.Response, web.StreamResponse]:
        """
        defines how to convert a model API response to a response to the PyWorker client
        """
        _ = client_request
        match model_response.status:
            case 200:
                log.debug("success")
                data = await model_response.json()
                return web.json_response(data=data)
            case code:
                log.debug("sending response: error: unknown code")
                return web.Response(status=code)
```

We also create a `GenerateStreamHandler` for streaming responses. It is identical to `GenerateHandler`, except for the endpoint name and how we create a web response, as it is a streaming response:

```python
class GenerateStreamHandler(EndpointHandler[InputData]):
    @property
    def endpoint(self) -> str:
        return "/generate_stream"

    @classmethod
    def payload_cls(cls) -> Type[InputData]:
        return InputData

    def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
        return dataclasses.asdict(payload)

    def make_benchmark_payload(self) -> InputData:
        return InputData.for_test()

    async def generate_client_response(
        self, client_request: web.Request, model_response: ClientResponse
    ) -> Union[web.Response, web.StreamResponse]:
        match model_response.status:
            case 200:
                log.debug("streaming response...")
                res = web.StreamResponse()
                res.content_type = "text/event-stream"
                await res.prepare(client_request)
                async for chunk in model_response.content:
                    await res.write(chunk)
                await res.write_eof()
                log.debug("done streaming response")
                return res
            case code:
                log.debug("sending response: error: unknown code")
                return web.Response(status=code)
```
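Before wiring these handlers into a `Backend`, it can help to see what a request to the PyWorker looks like from the client's side. The sketch below is purely illustrative: the worker address is a placeholder, and `auth_data` must be the authentication object returned by the autoscaler when a route is requested (see `client.py` and the Serverless docs); it is passed through to the worker unchanged.

```python
import requests

WORKER_URL = "http://WORKER_IP:WORKER_PORT"  # placeholder: the route returned by the autoscaler
auth_data = {}  # placeholder: the AuthData object returned by the autoscaler

with requests.post(
    f"{WORKER_URL}/generate_stream",
    json={
        "auth_data": auth_data,
        "payload": {"prompt": "Hello, world", "max_response_tokens": 64},
    },
    stream=True,
) as response:
    # GenerateStreamHandler forwards the model's chunks as they arrive
    for chunk in response.iter_content(chunk_size=None):
        print(chunk.decode(), end="", flush=True)
```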
You can now instantiate a `Backend` and use it to handle requests:

```python
from lib.backend import Backend, LogAction

# the URL and port of the model API
MODEL_SERVER_URL = "http://0.0.0.0:5001"

# this is the log line that is emitted once the server has started
MODEL_SERVER_START_LOG_MSG = "server has started"
MODEL_SERVER_ERROR_LOG_MSGS = [
    "Exception: corrupted model file"  # message in the logs indicating an unrecoverable error
]

backend = Backend(
    model_server_url=MODEL_SERVER_URL,
    # location of the model log file
    model_log_file=os.environ["MODEL_LOG"],
    # for model backends that can only handle one request at a time, be sure to set this to
    # False to let the PyWorker handle queueing requests
    allow_parallel_requests=True,
    # give the backend an EndpointHandler instance that is used for benchmarking;
    # the number of benchmark runs and the number of words for a random benchmark run are given
    benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
    # defines how to handle specific log messages; see the docstring of LogAction for details
    log_actions=[
        (LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
        (LogAction.Info, '"message":"Download'),
        *[
            (LogAction.ModelError, error_msg)
            for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
        ],
    ],
)


# this is a simple ping handler for the PyWorker
async def handle_ping(_: web.Request):
    return web.Response(body="pong")


# this is a handler for forwarding a health check to the model API
async def handle_healthcheck(_: web.Request):
    healthcheck_res = await backend.session.get("/healthcheck")
    return web.Response(body=healthcheck_res.content, status=healthcheck_res.status)


routes = [
    web.post("/generate", backend.create_handler(GenerateHandler())),
    web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())),
    web.get("/ping", handle_ping),
    web.get("/healthcheck", handle_healthcheck),
]

if __name__ == "__main__":
    # start the server; called from start_server.sh
    start_server(backend, routes)
```

## test_load.py

Here you can create a script that allows you to load test an endpoint group running instances with this PyWorker:

```python
from lib.test_harness import run

from .data_types import InputData

WORKER_ENDPOINT = "/generate"

if __name__ == "__main__":
    run(InputData.for_test(), WORKER_ENDPOINT)
```

You can then run the following command from the root of this repo to load test the endpoint group:

```bash
# sends 1000 requests at the rate of 0.5 requests per second
python3 -m workers.hello_world.test_load -n 1000 -rps 0.5 -k "$API_KEY" -e "$ENDPOINT_GROUP_NAME"
```
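This guide assumes an LLM model server is already listening on port 5001. For quick local experiments, a minimal stand-in like the sketch below (hypothetical, not part of the repo) can be useful. It echoes words from the prompt instead of running a real model, and prints the start-log line that the `LogAction.ModelLoaded` rule above watches for; in a real deployment that line must end up in the file pointed to by `$MODEL_LOG`.

```python
# mock_model_server.py — a hypothetical stand-in for the model API described above
import asyncio

from aiohttp import web


def _fake_tokens(data: dict) -> list:
    limit = data.get("max_response_tokens") or 50
    return data["prompt"].split()[:limit]


async def generate(request: web.Request) -> web.Response:
    data = await request.json()
    return web.json_response({"response": " ".join(_fake_tokens(data))})


async def generate_stream(request: web.Request) -> web.StreamResponse:
    data = await request.json()
    res = web.StreamResponse()
    res.content_type = "text/event-stream"
    await res.prepare(request)
    for token in _fake_tokens(data):
        await res.write((token + " ").encode())
        await asyncio.sleep(0.05)  # pretend to generate one token at a time
    await res.write_eof()
    return res


app = web.Application()
app.add_routes([
    web.post("/generate", generate),
    web.post("/generate_stream", generate_stream),
])

if __name__ == "__main__":
    # the Backend's LogAction.ModelLoaded rule looks for this exact line in the model log
    print("server has started", flush=True)
    web.run_app(app, port=5001)
```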