vLLM with Ray Serve Lab

Try out cutting-edge accelerated inference using vLLM in Ray Serve, along with a streaming request/response pattern for interacting with an LLM app.

vLLM accelerated deployment and streamed batching lab

This lab is an opportunity to familiarize yourself with vLLM-accelerated deployment in Ray Serve and with streaming responses back from an LLM app.

If you’re new to LLM applications, focus on the intro level activity. If you’ve worked with LLM apps before, you may have time to try the additional activities.

Intro level activity

The main activity in this lab is to refactor the code below, from Chen Shen and Cade Daniel’s work (referenced in this blog post: https://www.anyscale.com/blog/continuous-batching-llm-inference), so that you can test it out in a notebook.
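
To get oriented, here is a minimal sketch (not an official solution) of what the notebook version might look like once the deployment class below has been defined in a cell. It simply replaces the script’s __main__ block with direct calls:

import requests

from ray import serve

# Bind and run the deployment (same small model as the demo script).
deployment = VLLMPredictDeployment.bind(model="facebook/opt-125m")
serve.run(deployment)

# Send one streaming request; stream=True lets requests yield lines as they arrive.
sample_input = {"prompt": "How do I cook fried rice?", "stream": True}
with requests.post("http://localhost:8000/", json=sample_input, stream=True) as response:
    for line in response.iter_lines():
        print(line.decode("utf-8"))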

Intermediate level activity

In the initial script, only one request is sent and streamed back.

vLLM really shines when we send many requests concurrently. Use the “cities” requests from the Hosting with Ray notebook to generate 12 requests, and send them asynchronously to the model deployment.

Stream the output from each request back as it arrives; one possible pattern is sketched below.
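
One possible sketch, assuming the deployment is already running on localhost:8000 and using aiohttp for asynchronous streaming requests; the cities list here is a hypothetical stand-in for the 12 “cities” prompts in the Hosting with Ray notebook:

import asyncio

import aiohttp

# Hypothetical placeholder; substitute the 12 "cities" prompts from the Hosting with Ray notebook.
cities = ["Paris", "Tokyo", "Cairo", "Sydney"]
prompts = [f"Suggest some fun activities to do in {city}." for city in cities]

async def stream_one(session: aiohttp.ClientSession, prompt: str) -> None:
    payload = {"prompt": prompt, "stream": True}
    async with session.post("http://localhost:8000/", json=payload) as response:
        # The deployment streams newline-delimited JSON chunks; print each as it arrives.
        async for line in response.content:
            print(f"[{prompt[:24]}...] {line.decode('utf-8').strip()}")

async def main() -> None:
    async with aiohttp.ClientSession() as session:
        await asyncio.gather(*(stream_one(session, prompt) for prompt in prompts))

# In a notebook, run `await main()` instead, since an event loop is already running.
asyncio.run(main())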

Advanced activity

In the demo script, we only get back 10 or so tokens for each request. Modify the code so that we get back 200 tokens.
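
One way to approach this, sketched below: the deployment’s __call__ forwards any extra JSON fields into SamplingParams, and SamplingParams accepts a max_tokens argument (its default is small, which is why the demo responses are short). So the change can live entirely in the request payload:

# Client-side tweak: request up to 200 generated tokens via the sampling parameters.
sample_input = {
    "prompt": "How do I cook fried rice?",
    "stream": True,
    "max_tokens": 200,
}

Alternatively, set max_tokens where SamplingParams is constructed inside __call__.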

Bonus!

  • Since we’ve focused on LLMs and Ray, we haven’t built HTTP-based “front ends” for our services. This example includes such an HTTP adapter. Examine the code and try to work out how it operates: which pieces are “typical” HTTP processing boilerplate, and which pieces are Ray-specific.

  • Note that the quality of responses from the simpler model in the demo script will not be great. If you have extra time, feel free to “upgrade” to a more powerful model (see the sketch after this list). Hints: check the vLLM docs for compatible models, and recall that chat-tuned models may need a more complex prompting pattern.
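
A rough sketch only; the model name and prompt format below are illustrative and should be checked against the vLLM compatibility list and the model card of whatever you choose:

# A larger model from a family vLLM supports; swap in whichever compatible model you prefer.
deployment = VLLMPredictDeployment.bind(model="facebook/opt-1.3b")
serve.run(deployment)

# Chat-tuned models typically expect their own prompt template, e.g. something shaped like:
chat_style_prompt = "### Instruction:\nHow do I cook fried rice?\n\n### Response:\n"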

# code example from https://github.com/ray-project/ray/blob/cc983fc3e64c1ba215e981a43dd0119c03c74ff1/doc/source/serve/doc_code/vllm_example.py

import json
from typing import AsyncGenerator

from fastapi import BackgroundTasks
from starlette.requests import Request
from starlette.responses import StreamingResponse, Response
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

from ray import serve


@serve.deployment(ray_actor_options={"num_gpus": 1})
class VLLMPredictDeployment:
    def __init__(self, **kwargs):
        """
        Construct a VLLM deployment.

        Refer to https://github.com/vllm-project/vllm/blob/main/vllm/engine/arg_utils.py
        for the full list of arguments.

        Args:
            model: name or path of the huggingface model to use
            download_dir: directory to download and load the weights,
                default to the default cache dir of huggingface.
            use_np_weights: save a numpy copy of model weights for
                faster loading. This can increase the disk usage by up to 2x.
            use_dummy_weights: use dummy values for model weights.
            dtype: data type for model weights and activations.
                The "auto" option will use FP16 precision
                for FP32 and FP16 models, and BF16 precision.
                for BF16 models.
            seed: random seed.
            worker_use_ray: use Ray for distributed serving, will be
                automatically set when using more than 1 GPU
            pipeline_parallel_size: number of pipeline stages.
            tensor_parallel_size: number of tensor parallel replicas.
            block_size: token block size.
            swap_space: CPU swap space size (GiB) per GPU.
            gpu_memory_utilization: the percentage of GPU memory to be used for
                the model executor
            max_num_batched_tokens: maximum number of batched tokens per iteration
            max_num_seqs: maximum number of sequences per iteration.
            disable_log_stats: disable logging statistics.
            engine_use_ray: use Ray to start the LLM engine in a separate
                process as the server process.
            disable_log_requests: disable logging requests.
        """
        args = AsyncEngineArgs(**kwargs)
        self.engine = AsyncLLMEngine.from_engine_args(args)

    async def stream_results(self, results_generator) -> AsyncGenerator[bytes, None]:
        num_returned = 0
        async for request_output in results_generator:
            text_outputs = [output.text for output in request_output.outputs]
            assert len(text_outputs) == 1
            # request_output.outputs[0].text is cumulative, so yield only the new chunk.
            text_output = text_outputs[0][num_returned:]
            ret = {"text": text_output}
            yield (json.dumps(ret) + "\n").encode("utf-8")
            num_returned += len(text_output)

    async def may_abort_request(self, request_id) -> None:
        await self.engine.abort(request_id)

    async def __call__(self, request: Request) -> Response:
        """Generate completion for the request.

        The request should be a JSON object with the following fields:
        - prompt: the prompt to use for the generation.
        - stream: whether to stream the results or not.
        - other fields: the sampling parameters (See `SamplingParams` for details).
        """
        request_dict = await request.json()
        prompt = request_dict.pop("prompt")
        stream = request_dict.pop("stream", False)
        sampling_params = SamplingParams(**request_dict)
        request_id = random_uuid()
        results_generator = self.engine.generate(prompt, sampling_params, request_id)
        if stream:
            background_tasks = BackgroundTasks()
            # Use background_tasks to abort the request
            # if the client disconnects.
            background_tasks.add_task(self.may_abort_request, request_id)
            return StreamingResponse(
                self.stream_results(results_generator), background=background_tasks
            )

        # Non-streaming case
        final_output = None
        async for request_output in results_generator:
            if await request.is_disconnected():
                # Abort the request if the client disconnects.
                await self.engine.abort(request_id)
                return Response(status_code=499)
            final_output = request_output

        assert final_output is not None
        prompt = final_output.prompt
        text_outputs = [prompt + output.text for output in final_output.outputs]
        ret = {"text": text_outputs}
        return Response(content=json.dumps(ret))


def send_sample_request():
    import requests

    prompt = "How do I cook fried rice?"
    sample_input = {"prompt": prompt, "stream": True}
    # stream=True lets requests yield lines as the server produces them,
    # rather than buffering the whole response before iter_lines().
    output = requests.post("http://localhost:8000/", json=sample_input, stream=True)
    for line in output.iter_lines():
        print(line.decode("utf-8"))


if __name__ == "__main__":
    # To run this example, you need to install vllm which requires
    # OS: Linux
    # Python: 3.8 or higher
    # CUDA: 11.0 – 11.8
    # GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, etc.)
    # see https://vllm.readthedocs.io/en/latest/getting_started/installation.html
    # for more details.
    deployment = VLLMPredictDeployment.bind(model="facebook/opt-125m")
    serve.run(deployment)
    send_sample_request()