The Text Generation Inference (TGI) serverless template can be used to run LLM inference on Vast GPU instances. This page documents the required environment variables and the endpoints you need to get started. A full PyWorker and Client implementation can be found here.

Environment Variables

  • HF_TOKEN (string): HuggingFace API token with read permissions, used to download gated models. Read more about HuggingFace tokens here.
  • MODEL_ID (string): ID of the HuggingFace model to use for inference. Supported HuggingFace models are shown here.
Some models on HuggingFace require you to accept their terms and conditions on your HuggingFace account before they can be downloaded. For such models, accept the terms first before using them with a Vast template.
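For example, the template environment might look like the following (both values are placeholders; substitute your own token and any TGI-supported model ID):
HF_TOKEN=<your-huggingface-read-token>
MODEL_ID=Qwen/Qwen2.5-7B-Instruct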

Install the Vast.ai SDK

Ensure you have the vastai-sdk pip package installed:
pip install vastai-sdk

Ensure API key is set

Configure the environment variable VAST_API_KEY to contain your Vast.ai Serverless API key:
export VAST_API_KEY=<your-api-key>

Using /generate/

Python
import asyncio
from vastai import Serverless

MAX_TOKENS = 128

async def main():
    async with Serverless() as client:
        endpoint = await client.get_endpoint(name="my-tgi-endpoint")

        prompt = "Who are you?"

        payload = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": MAX_TOKENS,
                "temperature": 0.7,
                "return_full_text": False
            }
        }

        # cost is a workload estimate for this request (here, the max tokens requested)
        resp = await endpoint.request("/generate", payload, cost=MAX_TOKENS)

        print(resp["response"]["generated_text"])

if __name__ == "__main__":
    asyncio.run(main())
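
TGI's /generate route also accepts other standard sampling parameters alongside those shown above. As a sketch (the values below are illustrative, not recommendations), the payload in the example could be extended like this:

Python
# Sketch: extending the /generate payload with other standard TGI sampling
# parameters (values are illustrative only).
payload = {
    "inputs": "Who are you?",
    "parameters": {
        "max_new_tokens": 128,
        "temperature": 0.7,
        "top_p": 0.9,               # nucleus sampling cutoff
        "top_k": 50,                # sample only from the top-k tokens
        "repetition_penalty": 1.1,  # discourage repeated tokens
        "stop": ["\nUser:"],        # stop sequences end generation early
        "do_sample": True,
        "return_full_text": False,
    },
}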

Using /generate_stream/

Python
import asyncio
from vastai import Serverless

MAX_TOKENS = 1024

def build_prompt(system_prompt: str, user_prompt: str) -> str:
    return (
        f"<<SYS>>\n{system_prompt.strip()}\n<</SYS>>\n\n"
        f"User: {user_prompt.strip()}\n"
        f"Assistant:"
    )

async def main():
    async with Serverless() as client:
        endpoint = await client.get_endpoint(name="my-tgi-endpoint")

        system_prompt = (
            "You are Qwen.\n"
            "You are to only speak in English.\n"
        )
        user_prompt = """
        Critically analyze the extent to which hotdogs are sandwiches.
        """

        prompt = build_prompt(system_prompt, user_prompt)

        payload = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": MAX_TOKENS,
                "temperature": 0.7,
                "do_sample": True,
                "return_full_text": False,
            }
        }

        # stream=True makes the response an async iterator of token events
        resp = await endpoint.request(
            "/generate_stream",
            payload,
            cost=MAX_TOKENS,
            stream=True,
        )
        stream = resp["response"]

        printed_answer = False
        async for event in stream:
            tok = (event.get("token") or {}).get("text")
            if tok:
                if not printed_answer:
                    printed_answer = True
                    print("Answer:\n", end="", flush=True)
                print(tok, end="", flush=True)

if __name__ == "__main__":
    asyncio.run(main())
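
If you also want the complete text once streaming finishes (for logging, for example), a minimal variation on the loop above is to accumulate the streamed chunks as they arrive; this sketch assumes the same event shape used in the example:

Python
# Variation on the loop above: collect streamed chunks so the full
# completion text is available once the stream ends.
chunks = []
async for event in stream:
    tok = (event.get("token") or {}).get("text")
    if tok:
        chunks.append(tok)
        print(tok, end="", flush=True)
full_text = "".join(chunks)
print(f"\n\n({len(full_text)} characters streamed)")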