TRL + OpenEnv + Modal

Ben Glickenhaus

Motivation

The world of RL Fine-Tuning (RFT) is rapidly evolving. I'm particularly excited about the work Verifiers and Prime Intellect are doing, and particularly dismayed at the state of local development (especially on Apple silicon). While Verifiers makes it incredibly easy to spin up and evaluate arbitrarily complex (gym-style) interaction protocols (e.g., Wordle), prime-rl still requires managing persistent infra before you can start fine-tuning your own models. This is frustrating. Sometimes I just want to run things without worrying about surprise cloud bills a month later (skill issue, I know).


My mission: set up a harness to run gym-style RFT on serverless compute. None of this is particularly revolutionary; in the rest of this post, I'll step through how I glued together existing tech and tutorials to achieve this goal.

Our Gym Env

We're going to adopt the dataset from TRL's DPO example: UltraFeedback (trl-lib/ultrafeedback_binarized). Our model's task will be to predict the score assigned to a completion from that dataset. This is a trivial toy example to get us started; it's also a happy coincidence that I needed to do this for a different project (stay tuned for that blog post).


OpenEnv environment logic mostly lives in an Environment subclass:

from collections.abc import Callable
from uuid import uuid4

# Environment and State are OpenEnv's base types; the import path depends on your OpenEnv version


class UltrafeedbackEnvironment(Environment):
    def __init__(
        self,
        dataset_name: str = "trl-lib/ultrafeedback_binarized",
        split: str = "train",
        seed: int | None = None,
        reward_fn: Callable | None = None,
    ):
        self.dataset_name = dataset_name
        self.split = split
        self.reward_fn = reward_fn

        # Load and format dataset
        self._load_dataset()

        # Initialize state
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._current_sample: dict | None = None

In __init__, we just load and process our dataset, then initialize the episode state.


Let's quickly take a look at our system prompt and dataset processing:

from datasets import Dataset, Value, load_dataset

SYSTEM_PROMPT = "You are presented with a user prompt and a response from an LLM assistant. Grade the response to the prompt as a number between 0 and 10. DO NOT provide an answer to the original human prompt. The answer should be formatted as <answer>{number}</answer>."


def _load_dataset(self):
    """Load and format the dataset."""
    dataset = load_dataset(self.dataset_name, split=self.split)
    system_message = {"role": "system", "content": SYSTEM_PROMPT}
    new_samples = {"prompt": [], "answer": []}

    for sample in dataset:
        # Add chosen response
        new_samples["prompt"].append([system_message] + sample["chosen"])
        new_samples["answer"].append(sample["score_chosen"])
        # Add rejected response
        new_samples["prompt"].append([system_message] + sample["rejected"])
        new_samples["answer"].append(sample["score_rejected"])

    new_dataset = Dataset.from_dict(new_samples)
    # Cast answer to string
    new_dataset = new_dataset.cast_column("answer", Value("string"))
    # Store the formatted dataset so reset() can sample from it
    self.dataset = new_dataset

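Above, we convert the dataset from DPO format into simple prompt and answer columns: each DPO row yields two samples, one for the chosen response and one for the rejected response, each labeled with its human-assigned score. For the sake of this example, the reward is based on the absolute error between the model's predicted score and the human-assigned score (1 minus the error divided by 10, clipped at 0), plus a small bonus for correct formatting.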
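To make the reshaping concrete, here is a minimal, dependency-free sketch of what happens to a single row. The row is a hand-written toy example shaped like trl-lib/ultrafeedback_binarized (chosen/rejected message lists plus score_chosen/score_rejected), not real dataset content, and the helper name to_regression_samples is hypothetical.

```python
# Shortened placeholder; the real environment uses the full SYSTEM_PROMPT above
SYSTEM_PROMPT_PLACEHOLDER = "Grade the response as a number between 0 and 10."


def to_regression_samples(row: dict, system_prompt: str) -> list[dict]:
    """Split one DPO-style row into two (prompt, answer) samples."""
    system_message = {"role": "system", "content": system_prompt}
    return [
        # Chosen conversation, labeled with its score (kept as a string)
        {"prompt": [system_message] + row["chosen"], "answer": str(row["score_chosen"])},
        # Rejected conversation, labeled with its (usually lower) score
        {"prompt": [system_message] + row["rejected"], "answer": str(row["score_rejected"])},
    ]


# Hypothetical toy row, not taken from the dataset
row = {
    "chosen": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "2 + 2 = 4."},
    ],
    "rejected": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "5"},
    ],
    "score_chosen": 9.0,
    "score_rejected": 2.0,
}

samples = to_regression_samples(row, SYSTEM_PROMPT_PLACEHOLDER)
for sample in samples:
    print(len(sample["prompt"]), sample["answer"])
```

Each DPO pair becomes two independent regression targets, so the model sees both high- and low-scored responses during training.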


The reward calculation happens in the environment step method:

def step(self, action: UltrafeedbackAction) -> UltrafeedbackObservation:
    if self._current_sample is None:
        raise RuntimeError("Environment not initialized. Call reset() first.")
 
    self._state.step_count += 1
 
    completion = action.completion
    correct_answer_str = self._current_sample["answer"]
 
    response = parse_numeric_answer(completion)
 
    if response is None:
        correctness_reward = 0.0
    else:
        correct_answer = float(correct_answer_str)
        difference = abs(response - correct_answer)
        correctness_reward = 1.0 - (difference / 10.0)
        correctness_reward = max(0.0, correctness_reward)
 
    format_reward = calculate_format_reward(completion)
    total_reward = (correctness_reward * 1.0 + format_reward * 0.2) / 1.2
 
    return UltrafeedbackObservation(
        prompt=self._current_sample["prompt"],
        correct_answer=correct_answer_str,
        done=True,
        reward=total_reward,
        metadata={
            "predicted_score": response,
            "correct_score": correct_answer_str,
            "correctness_reward": correctness_reward,
            "format_reward": format_reward,
        },
    )

Since this is a toy environment, the step function boils down to a single supervised-learning style check: it extracts the LLM's formatted answer, compares it to the ground-truth score, and ends the episode (done=True) after one step.


This, plus some plumbing covered in the OpenEnv docs, is all you need to define your interaction model for RFT. Next, we'll look at how to actually use it in the training loop.

TRL

In this section, we'll look at how we can easily plug our custom gym environment into TRL's GRPOTrainer class. We'll start by instantiating our environment:

env = UltrafeedbackEnvironment()

# Hack: `GRPOTrainer` still expects a dataset argument, but we handle dataset logic in our env
train_dataset = Dataset.from_dict({"prompt": ["Blah blah this will be ignored"] * 64})

Next, we need to define our rollout_func: this is what tells TRL how to interact with our environment.

Note

For GRPO's group-relative advantages (its variance-reduction trick) to work, every rollout in a group must start from the same initial prompt. This is why rollout_func calls env.reset with the same seed for every generation in a group. The same pattern extends easily to multi-step environments like Wordle or gridworld.
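Why does the shared starting prompt matter? GRPO scores each rollout relative to the other rollouts in its group, so the group's mean reward acts as a baseline. Below is a minimal sketch of that normalization, assuming rewards arrive in contiguous groups of num_generations (the order our rollout function produces) and using the sample standard deviation plus a small epsilon. TRL's exact normalization depends on your version and config (for example its scale_rewards setting), and group_relative_advantages is a hypothetical name.

```python
import statistics


def group_relative_advantages(
    rewards: list[float], num_generations: int, eps: float = 1e-4
) -> list[float]:
    """GRPO-style advantages: (reward - group mean) / (group std + eps)."""
    if num_generations < 2 or len(rewards) % num_generations != 0:
        raise ValueError("need complete groups of at least 2 rollouts each")
    advantages: list[float] = []
    for start in range(0, len(rewards), num_generations):
        group = rewards[start : start + num_generations]
        mean = statistics.fmean(group)
        std = statistics.stdev(group)  # sample std of this group's rewards
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages


# Two groups of 4 rollouts: the first has signal, the second is all ties
advantages = group_relative_advantages(
    [0.9, 0.5, 0.5, 0.1, 0.7, 0.7, 0.7, 0.7], num_generations=4
)
```

Rollouts that beat their group's average get positive advantages and the rest get negative ones, while a group where every rollout earns the same reward contributes nothing. If a group's rollouts started from different prompts, the baseline would mix easy and hard prompts and the comparison would lose its meaning.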

Our rollout logic is split into two functions:

  • rollout_episode: a helper function that executes a single episode of our environment
  • rollout_func: a wrapper that executes all the episodes needed for a GRPO training step

This is a lot of code, so it's worth reading piece by piece.

Rollout Function Implementation

import random

from transformers import AutoTokenizer, PreTrainedTokenizerBase
from trl import GRPOConfig
from trl.extras.vllm_client import VLLMClient

# model_name, env, and the Ultrafeedback* classes come from the earlier snippets
vllm_client = VLLMClient()  # talks to the `trl vllm-serve` server (default host and port)
tokenizer = AutoTokenizer.from_pretrained(model_name)


def rollout_episode(
    training_args: GRPOConfig,
    env: UltrafeedbackEnvironment,
    tokenizer: PreTrainedTokenizerBase,
    vllm_client: VLLMClient,
    seed: int | None = None,
    max_steps: int = 10,
) -> dict[str, list]:
    # The same seed yields the same initial prompt for every rollout in a group
    obs = env.reset(seed=seed)
    initial_prompt = obs.prompt

    prompt_ids: list[int] = []
    completion_ids: list[int] = []
    logprobs: list[float] = []
    raw_rewards: list[float] = []
    completion_texts: list[str] = []

    current_obs = obs
    step_count = 0
    episode_reward = 0.0

    while not current_obs.done and step_count < max_steps:
        response = vllm_client.chat(
            # For our dataset, `prompt` is a chat-format list of messages
            [current_obs.prompt],
            temperature=training_args.temperature,
            top_k=training_args.top_k or -1,
            max_tokens=training_args.max_completion_length or 16,
        )

        step_prompt_ids = response["prompt_ids"][0]
        step_completion_ids = response["completion_ids"][0]
        step_logprobs = response["logprobs"][0]

        completion_text = tokenizer.decode(
            step_completion_ids, skip_special_tokens=True
        )

        action = UltrafeedbackAction(completion=completion_text)
        current_obs = env.step(action)

        prompt_ids.extend(step_prompt_ids)
        completion_ids.extend(step_completion_ids)
        logprobs.extend(step_logprobs)
        completion_texts.append(completion_text)

        step_reward = float(current_obs.reward or 0.0)
        raw_rewards.append(step_reward)
        episode_reward += step_reward

        step_count += 1

    return {
        "prompt_ids": prompt_ids,
        "completion_ids": completion_ids,
        "logprobs": logprobs,
        "raw_rewards": [episode_reward],
        "completion_text": completion_texts,
        "prompt": initial_prompt,
        "correct_answer": [getattr(obs, "correct_answer", "")],
    }


def rollout_func(
    prompts: list[str], args: GRPOConfig, processing_class: PreTrainedTokenizerBase
):
    # `env` and `vllm_client` are read from the enclosing scope. Reading a
    # captured variable needs no `nonlocal` (that is only required to rebind it).
    assert vllm_client is not None

    episode_prompt_ids: list[list[int]] = []
    episode_completion_ids: list[list[int]] = []
    episode_logprobs: list[list[float]] = []
    correctness_rewards: list[float] = []

    for _ in prompts:
        # One seed per group, so every generation starts from the same prompt
        group_seed = random.randint(0, 2**31 - 1)
        num_gens = args.num_generations or 1
        for _ in range(num_gens):
            rollout_result = rollout_episode(
                training_args=args,
                env=env,
                tokenizer=processing_class,
                vllm_client=vllm_client,
                seed=group_seed,
                max_steps=getattr(args, "max_episode_steps", 10),
            )

            episode_prompt_ids.append(rollout_result["prompt_ids"])
            episode_completion_ids.append(rollout_result["completion_ids"])
            episode_logprobs.append(rollout_result["logprobs"])
            correctness_rewards.append(rollout_result["raw_rewards"][0])

    return {
        "prompt_ids": episode_prompt_ids,
        "completion_ids": episode_completion_ids,
        "logprobs": episode_logprobs,
        "raw_rewards": correctness_rewards,
    }
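Capturing the vLLM client in a closure: TRL is very opinionated about which arguments it passes to rollout_func (prompts, args, and processing_class), so anything else our custom rollout_episode function needs, such as env and vllm_client, has to be captured from the enclosing scope in a closure.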

Okay, now that we have our rollout functions, we can hook them up to TRL! The following should look familiar if you've ever followed a TRL tutorial:

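from pathlib import Path

from trl import GRPOConfig, GRPOTrainer


def reward_from_env(
    prompts: list[str], completions: list[str], **kwargs
) -> list[float]:
    """Extract environment rewards."""
    # Extra keys returned by rollout_func (here "raw_rewards") are forwarded as kwargs
    env_rewards = kwargs.get("raw_rewards", [])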
    return (
        [float(reward) for reward in env_rewards]
        if env_rewards
        else [0.0] * len(completions)
    )
 
trainer = GRPOTrainer(
    model=model_name,
    reward_funcs=reward_from_env,
    train_dataset=train_dataset,
    rollout_func=rollout_func,
    processing_class=tokenizer,
    args=GRPOConfig(
        output_dir=str(Path(f"/models/{experiment_name}")),
        vllm_mode="server",
        use_vllm=True,
        num_train_epochs=1,
        num_generations=4,
        max_completion_length=128,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=1,
        logging_steps=1,
        report_to="wandb",
        temperature=0.8,
        top_k=10,
    ),
)
 
trainer.train()
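It's worth sanity-checking how the batch settings interact with num_generations. Here is a back-of-the-envelope helper, assuming TRL's convention that per_device_train_batch_size counts completions (not prompts) and that each generation round covers per_device_train_batch_size × num_processes × steps_per_generation completions, which must split evenly into groups of num_generations. Verify this against the GRPOConfig docs for your TRL version; the function name is hypothetical.

```python
def prompts_per_generation_round(
    per_device_train_batch_size: int,
    num_generations: int,
    num_processes: int = 1,
    steps_per_generation: int = 1,
) -> int:
    """Distinct prompts per generation round, under the assumptions in the lead-in."""
    completions = per_device_train_batch_size * num_processes * steps_per_generation
    if completions % num_generations != 0:
        raise ValueError(
            f"{completions} completions per round cannot be split into groups of {num_generations}"
        )
    return completions // num_generations


print(prompts_per_generation_round(per_device_train_batch_size=4, num_generations=4))  # → 1
```

Under these assumptions, the config above (4 completions per device, 4 generations, one training process) covers a single distinct prompt per round, so rollout_func should receive one prompt per call and run four episodes from it.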

Deploying on Modal

At this point, assuming you can run vLLM locally, you could run the training code on your own machine. Unfortunately for me, installing vLLM on Apple silicon is a pain in the ass, so Modal comes to the rescue.


To run our training script on Modal, we just need to define a modal.Image, grab two GPUs (one for vLLM inference, one for training), and off we go.

Tip

Modal supports native Python breakpoints! Just run your Modal function in interactive mode (modal run -i app.py::function). This is a lifesaver if you need to debug your code or just want to understand the flow of data in this example.



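from datetime import datetime

import modal

# MODELS_DIR and checkpoints_volume are defined elsewhere in the repo (not shown here)
app = modal.App(name="trl-openenv")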
 
PYPROJECT_DEPENDENCIES = [
    "datasets>=4.4.1",
    "transformers>=4.57.1",
    "trl[vllm]>=0.25.1",
    "wandb>=0.22.3",
    "weave",
]
 
image = (
    modal.Image.debian_slim(python_version="3.12")
    .uv_pip_install(*PYPROJECT_DEPENDENCIES)
    # I have OpenEnv as a git subtree so Cursor can easily reference it, you could also install from Github
    .add_local_dir("OpenEnv", remote_path="/repos/OpenEnv", copy=True)
    .uv_pip_install("/repos/OpenEnv")
    # Add ultrafeedback environment
    .add_local_python_source("ultrafeedback", copy=True)
    .add_local_dir("src", remote_path="/root/src")
)
 
def _train_grpo(model_name: str, experiment_name: str):
    start_grpo_trainer(model_name, experiment_name)
 
 
@app.function(
    image=image,
    gpu="A10G:2",
    timeout=60 * 60 * 24,  # 24 hours
    secrets=[modal.Secret.from_name("wandb-secret")],
    volumes={str(MODELS_DIR): checkpoints_volume},
)
def train_grpo(model_name: str, experiment_name: str):
    experiment_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    _train_grpo(model_name, f"{experiment_name}_{experiment_timestamp}")
 

In start_grpo_trainer, we just need to launch vLLM as a subprocess on one GPU, then run the training loop we defined above on the other:

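import os
import subprocess


def start_grpo_trainer(model_name: str, experiment_name: str):
    env_copy = os.environ.copy()
    env_copy["CUDA_VISIBLE_DEVICES"] = "0"  # Serve vLLM on GPU 0

    # Start vllm-serve in the background
    vllm_process = subprocess.Popen(
        [
            "trl",
            "vllm-serve",
            "--model",
            model_name,
            "--gpu-memory-utilization",
            "0.8",
        ],
        env=env_copy,
    )

    try:
        # Run training on the other GPU. This must be set before CUDA is
        # initialized in this process, otherwise it has no effect.
        os.environ["CUDA_VISIBLE_DEVICES"] = "1"
        trainer = ...  # GRPOTrainer setup and trainer.train() from the TRL section
    finally:
        # Shut down the vLLM server whether training succeeds or fails
        vllm_process.terminate()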

Training a model

Kicking off a training run is easy: modal run app.py::train_grpo --model-name Qwen/Qwen2.5-0.5B-Instruct --experiment-name <your-experiment-name>


We can monitor the run on wandb and confirm that the reward goes up.

Conclusion

In this blog post, we glued together TRL, OpenEnv, and Modal to run gym style RFT on serverless compute. The full code is available here.


This project emerged from a larger experiment I'm running in this repo, so please ignore some of the slop in the other files for now :)

The main code pointers for this post are:


I hope you found this post helpful. Please feel free to reach out on X or in the repo if you have any questions or feedback.