#!/usr/bin/env python3
import hashlib
import hmac
import json
import os
import time
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
import requests
def _require_env(name: str) -> str:
    """Return environment variable *name*, exiting if it is unset or empty."""
    value = os.environ.get(name, "")
    if not value:
        raise SystemExit(f"{name} must be set to a non-empty value")
    return value


# Slack incoming-webhook URL that notifications are POSTed to.
SLACK_WEBHOOK_URL = _require_env("SLACK_WEBHOOK_URL")
# Shared HMAC-SHA256 key used to check X-Vast-Signature-256. An empty key would
# let anyone compute a "valid" signature, so it is rejected at startup.
VAST_WEBHOOK_SECRET = _require_env("VAST_WEBHOOK_SECRET")
PORT = int(os.environ.get("PORT", "8787"))
# Largest accepted difference (seconds, either direction) between the
# X-Vast-Timestamp header and local time.
MAX_SIGNATURE_AGE_SECONDS = 300
# Base console URL. A trailing slash is stripped so the joined paths below do
# not contain "//".
CONSOLE = os.environ.get("VAST_CONSOLE", "https://cloud.vast.ai").rstrip("/")
_BILLING_URL = f"{CONSOLE}/billing/"
_INSTANCES_URL = f"{CONSOLE}/instances/"
# Link appended to the Slack message for each notif_type; types not listed
# here fall back to CONSOLE in slack_message().
ACTION_URLS = {
    "low_credit": _BILLING_URL,
    "billing_failed": _BILLING_URL,
    "payment_receipt": _BILLING_URL,
    "instance_created": _INSTANCES_URL,
    "instance_started": _INSTANCES_URL,
    "instance_stopped": _INSTANCES_URL,
    "instance_offline": _INSTANCES_URL,
    "instance_online": _INSTANCES_URL,
    "outbid": _INSTANCES_URL,
    "upcoming_downtime": _INSTANCES_URL,
    "webhook_test": CONSOLE,
}
def verify_vast_signature(
    headers,
    raw_body: bytes,
    *,
    secret: "str | None" = None,
    max_age: "float | None" = None,
    now: "float | None" = None,
) -> bool:
    """Check the Vast.ai HMAC signature and freshness of a webhook request.

    The expected signature is ``"sha256=" + hex(HMAC-SHA256(secret,
    b"<timestamp>." + raw_body))``, where ``<timestamp>`` is the exact text of
    the ``X-Vast-Timestamp`` header. It is compared in constant time with
    ``X-Vast-Signature-256``.

    Args:
        headers: Object with a ``get(name, default)`` method, e.g. the
            request's header message.
        raw_body: Exact request body bytes, before any decoding.
        secret: HMAC key; defaults to ``VAST_WEBHOOK_SECRET``.
        max_age: Largest allowed ``|now - timestamp|`` in seconds; defaults to
            ``MAX_SIGNATURE_AGE_SECONDS``.
        now: Current Unix time in seconds; defaults to ``time.time()``.

    Returns:
        True only if the key is non-empty, both headers are present, the
        timestamp is an integer within *max_age* of *now*, and the digest
        matches. Missing headers, a non-integer or out-of-range timestamp,
        and a non-ASCII signature all yield False instead of raising.
    """
    if secret is None:
        secret = VAST_WEBHOOK_SECRET
    if max_age is None:
        max_age = MAX_SIGNATURE_AGE_SECONDS
    if now is None:
        now = time.time()

    # An empty key makes every signature trivially forgeable.
    if not secret:
        return False

    timestamp = headers.get("X-Vast-Timestamp", "")
    signature = headers.get("X-Vast-Signature-256", "")
    if not timestamp or not signature.startswith("sha256="):
        return False
    # The expected value is ASCII hex, and compare_digest() raises TypeError
    # for non-ASCII str arguments, so such a header can simply never match.
    if not signature.isascii():
        return False

    try:
        age = abs(now - int(timestamp))
    except (ValueError, OverflowError):
        # ValueError: not an integer. OverflowError: too large to subtract
        # from a float clock value.
        return False
    if age > max_age:
        return False

    signed = timestamp.encode("utf-8") + b"." + raw_body
    digest = hmac.new(secret.encode("utf-8"), signed, hashlib.sha256).hexdigest()
    return hmac.compare_digest(signature, f"sha256={digest}")
def _slack_escape(text: str) -> str:
    """Escape the characters Slack treats as markup control characters.

    Slack message text uses ``<...>`` for links and mentions (such as
    ``<!channel>``), so ``&``, ``<`` and ``>`` in passed-through text are
    replaced by HTML entities. ``&`` is replaced first to avoid double escaping.
    """
    return text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")


def slack_message(payload: dict) -> dict:
    """Build the Slack incoming-webhook body for one Vast.ai notification.

    The ``text`` field has four newline-separated lines:

    1. The subject.
    2. The message, or the whole payload as sorted JSON if there is no message.
    3. The console link for the ``notif_type`` (``CONSOLE`` for unknown types).
    4. A ``type=... [event_id=...]`` details line.

    Text taken from the payload is Slack-escaped.

    Args:
        payload: Decoded webhook JSON object.

    Returns:
        A dict of the form ``{"text": <message text>}``.
    """
    subject = payload.get("subject") or "Vast.ai notification"
    message = payload.get("message") or json.dumps(payload, sort_keys=True)
    # Coerce to str: a JSON list/object here is unhashable and would make the
    # ACTION_URLS lookup raise TypeError.
    notif_type = str(payload.get("notif_type") or "notification")
    event_id = payload.get("event_id")
    action_url = ACTION_URLS.get(notif_type, CONSOLE)
    details = [f"type={notif_type}"]
    if event_id:
        details.append(f"event_id={event_id}")
    lines = [
        _slack_escape(str(subject)),
        _slack_escape(str(message)),
        action_url,
        _slack_escape(" ".join(details)),
    ]
    return {"text": "\n".join(lines)}
class Handler(BaseHTTPRequestHandler):
    """Serve ``GET /health`` and forward signed Vast.ai webhook POSTs to Slack.

    Every response body is a small JSON object with an ``ok`` flag.
    """

    # Largest request body accepted. The body is read before the signature is
    # checked, so without a cap any client could make a worker thread buffer
    # an arbitrarily large payload.
    max_body_bytes = 1024 * 1024

    def _json(self, status: int, body: dict) -> None:
        """Send *body* serialized as JSON with HTTP status *status*."""
        data = json.dumps(body).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(data)))
        self.end_headers()
        self.wfile.write(data)

    def _read_body(self) -> "bytes | None":
        """Read the request body as declared by Content-Length.

        Returns:
            The body bytes (empty if the header is absent). Returns None after
            sending a 400 (non-integer or negative length) or 413 (larger than
            ``max_body_bytes``) response.
        """
        raw_length = self.headers.get("Content-Length", "0") or "0"
        try:
            length = int(raw_length)
        except ValueError:
            length = -1
        if length < 0:
            # rfile.read() with a negative size would block until the client
            # closes the connection.
            self._json(400, {"ok": False, "error": "invalid_content_length"})
            return None
        if length > self.max_body_bytes:
            self._json(413, {"ok": False, "error": "payload_too_large"})
            return None
        return self.rfile.read(length)

    def do_GET(self):
        """Answer ``/health`` with 200; every other path gets 404."""
        if self.path == "/health":
            self._json(200, {"ok": True})
            return
        self._json(404, {"ok": False, "error": "not_found"})

    def do_POST(self):
        """Verify, parse and forward one webhook delivery to Slack.

        Responses:
            400/413: bad or oversized Content-Length.
            401: invalid signature.
            400: body that is not UTF-8 JSON, or not a JSON object.
            502: Slack unreachable, or Slack answered with a status >= 400.
            200: Slack accepted the message.
        """
        raw_body = self._read_body()
        if raw_body is None:
            return
        if not verify_vast_signature(self.headers, raw_body):
            self._json(401, {"ok": False, "error": "invalid_signature"})
            return
        try:
            # An empty body is treated as an empty JSON object.
            payload = json.loads(raw_body.decode("utf-8") or "{}")
        except (UnicodeDecodeError, json.JSONDecodeError):
            self._json(400, {"ok": False, "error": "invalid_json"})
            return
        if not isinstance(payload, dict):
            self._json(400, {"ok": False, "error": "invalid_payload"})
            return
        try:
            response = requests.post(
                SLACK_WEBHOOK_URL,
                json=slack_message(payload),
                timeout=10,
            )
        except requests.RequestException as exc:
            # Without this, a timeout or connection error would escape the
            # handler and the caller would get no HTTP response. Only the
            # exception type is logged: requests' messages can include the
            # request URL, and the Slack webhook URL is a secret.
            print(f"slack request failed: {type(exc).__name__}", flush=True)
            self._json(502, {"ok": False, "error": "slack_unreachable"})
            return
        if response.status_code >= 400:
            self._json(502, {"ok": False, "slack_status": response.status_code})
            return
        print(
            f"forwarded notif_type={payload.get('notif_type')} "
            f"event_id={payload.get('event_id')}",
            flush=True,
        )
        self._json(200, {"ok": True})

    def log_message(self, fmt, *args):
        """Print access-log lines to stdout, prefixed with the client address."""
        print(f"{self.address_string()} - {fmt % args}", flush=True)
if __name__ == "__main__":
    # Listen on loopback by default (as before); set HOST to bind elsewhere.
    host = os.environ.get("HOST", "127.0.0.1")
    with ThreadingHTTPServer((host, PORT), Handler) as server:
        print(f"listening on http://{host}:{PORT}", flush=True)
        try:
            server.serve_forever()
        except KeyboardInterrupt:
            # Ctrl-C: exit quietly; leaving the with-block closes the socket.
            pass