vLLM with Ray Serve Lab Solution

Here are the solutions to the vLLM with Ray Serve lab.

import json
from typing import AsyncGenerator
import requests
from fastapi import BackgroundTasks
from starlette.requests import Request
from starlette.responses import StreamingResponse, Response
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

from ray import serve

Core deployment definition

@serve.deployment(ray_actor_options={"num_gpus": 1})
class VLLMPredictDeployment:
    def __init__(self, **kwargs):
        # Refer to https://github.com/vllm-project/vllm/blob/main/vllm/engine/arg_utils.py for the full list of arguments.
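        # Common options include model, download_dir, tensor_parallel_size,
        # dtype, and gpu_memory_utilization (all fields of the EngineArgs dataclass);
        # unrecognized keywords raise a TypeError.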
        args = AsyncEngineArgs(**kwargs)
        self.engine = AsyncLLMEngine.from_engine_args(args)

    async def stream_results(self, results_generator) -> AsyncGenerator[bytes, None]:
        num_returned = 0
        async for request_output in results_generator:
            text_outputs = [output.text for output in request_output.outputs]
            assert len(text_outputs) == 1
            text_output = text_outputs[0][num_returned:]
            ret = {"text": text_output}
            yield (json.dumps(ret) + "\n").encode("utf-8")
            num_returned += len(text_output)

    async def may_abort_request(self, request_id) -> None:
        await self.engine.abort(request_id)

    async def __call__(self, request: Request) -> Response:
        # The request should be a JSON object with a "prompt" field and an optional "stream" flag (True/False); any remaining keys are passed as kwargs to vLLM's SamplingParams.
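        # For example: {"prompt": "Hello", "stream": true, "max_tokens": 64, "temperature": 0.8};
        # extra keys such as max_tokens and temperature map onto SamplingParams fields.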
        
        request_dict = await request.json()
        prompt = request_dict.pop("prompt")
        stream = request_dict.pop("stream", False)
        sampling_params = SamplingParams(**request_dict)
        request_id = random_uuid()
        results_generator = self.engine.generate(prompt, sampling_params, request_id)
        if stream:
            background_tasks = BackgroundTasks()
            # Use background_tasks to abort the request
            # if the client disconnects.
            background_tasks.add_task(self.may_abort_request, request_id)
            return StreamingResponse(
                self.stream_results(results_generator), background=background_tasks
            )

        # Non-streaming case
        final_output = None
        async for request_output in results_generator:
            if await request.is_disconnected():
                # Abort the request if the client disconnects.
                await self.engine.abort(request_id)
                return Response(status_code=499)
            final_output = request_output

        assert final_output is not None
        prompt = final_output.prompt
        text_outputs = [prompt + output.text for output in final_output.outputs]
        ret = {"text": text_outputs}
        return Response(content=json.dumps(ret))

Our config for testing

model = 'facebook/opt-125m'
download_dir = '/mnt/local_storage'

prompt = 'What is your favorite place to visit in San Francisco?'

Start application on Serve

deployment = VLLMPredictDeployment.bind(model=model, download_dir=download_dir)
serve.run(deployment, name='vllm')
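
As an optional sanity check before sending traffic, print the Serve controller state. This assumes a Ray Serve 2.x API where serve.status() is available; the 'vllm' application should report RUNNING once the replica has downloaded and loaded the model.

print(serve.status())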

Test and print output

sample_input = {"prompt": prompt, "stream": True}
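# Note: without stream=True passed to requests.post itself, the whole response body is
# buffered before iter_lines() runs, so tokens will not appear incrementally on the client.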
output = requests.post("http://localhost:8000/", json=sample_input)
for line in output.iter_lines():
    print(line.decode("utf-8"))
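
The deployment also handles non-streaming requests. Setting stream to False exercises the second branch of __call__, which returns a single JSON payload with the prompt prepended to the generated text; max_tokens here is just an example SamplingParams argument.

sample_input = {"prompt": prompt, "stream": False, "max_tokens": 64}
output = requests.post("http://localhost:8000/", json=sample_input)
print(output.json()["text"])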

Run multiple requests

Note that map is lazy, so these requests are sent one at a time as the printing loop consumes them; a concurrent variant is sketched after this block.

cities = ['Atlanta', 'Boston', 'Chicago', 'Vancouver', 'Montreal', 'Toronto', 'Frankfurt', 'Rome', 'Warsaw', 'Cairo', 'Dar Es Salaam', 'Gaborone']
prompts = [f'What is your favorite place to visit in {city}?' for city in cities]

def send(m):
    return requests.post("http://localhost:8000/", json={"prompt": m, "stream": True})

outputs = map(send, prompts)
for output in outputs:
    for line in output.iter_lines():
        print(line.decode("utf-8"))
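
To actually overlap requests from the client side and exercise the engine's continuous batching, a thread pool is a simple option. This is a minimal sketch using only the standard library and the send helper defined above:

from concurrent.futures import ThreadPoolExecutor

# Send all prompts concurrently; executor.map yields results in input order.
with ThreadPoolExecutor(max_workers=8) as executor:
    for output in executor.map(send, prompts):
        for line in output.iter_lines():
            print(line.decode("utf-8"))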

Change the request to get 200 tokens in each response by passing max_tokens, which is forwarded to SamplingParams

def send(m):
    return requests.post("http://localhost:8000/", json={"prompt": m, "stream": True, "max_tokens": 200})

outputs = map(send, prompts)
for output in outputs:
    for line in output.iter_lines():
        print(line.decode("utf-8"))

Shut down Serve when you are finished

serve.shutdown()