TRL + OpenEnv + Modal
Motivation
The world of RL Fine-Tuning (RFT) is evolving rapidly. I'm particularly excited about the work Verifiers and prime-intellect are doing. I'm much less excited about the state of local development (especially on Apple silicon). Verifiers makes it incredibly easy to spin up and evaluate arbitrarily complex gym-style interaction protocols (e.g., Wordle). prime-rl, however, still requires managing persistent infra before you can start fine-tuning your own models. This is frustrating. Sometimes I just want to run things without worrying about a surprise cloud bill a month later (skill issue, I know).
My mission: set up a harness to run gym-style RFT on serverless compute. None of this is particularly revolutionary. In the rest of this post I'll walk through how I glued together existing tech and tutorials to get there.
- RFT: TRL, which is pretty much the industry standard
- Gym env: PyTorch's new OpenEnv (Verifiers maintains its own forked TRL wrapper, so it doesn't integrate perfectly with TRL)
- Serverless: Modal
Our Gym Env
We're going to adopt the dataset from TRL's DPO example: ultrafeedback. Our model's task will be to predict the score assigned to a completion from that dataset. This is a trivial toy example to get us started. It's also a happy coincidence, since I needed to do this for a different project (stay tuned for that blog post).
OpenEnv environment logic mostly lives in an Environment subclass:
```python
from collections.abc import Callable
from uuid import uuid4

# Environment and State come from OpenEnv
# (the exact import path depends on your OpenEnv version).


class UltrafeedbackEnvironment(Environment):
    def __init__(
        self,
        dataset_name: str = "trl-lib/ultrafeedback_binarized",
        split: str = "train",
        seed: int | None = None,
        reward_fn: Callable | None = None,
    ):
        self.dataset_name = dataset_name
        self.split = split
        self.reward_fn = reward_fn

        # Load and format the dataset
        self._load_dataset()

        # Initialize episode state
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._current_sample: dict | None = None
```

In `__init__`, we just load and format the dataset, then initialize the episode state.
Let's quickly take a look at our system prompt and dataset processing:
```python
from datasets import Dataset, Value, load_dataset

SYSTEM_PROMPT = (
    "You are presented with a user prompt and a response from an LLM assistant. "
    "Grade the response to the prompt as a number between 0 and 10. "
    "DO NOT provide an answer to the original human prompt. "
    "The answer should be formatted as <answer>{number}</answer>."
)


# Method of UltrafeedbackEnvironment (shown on its own)
def _load_dataset(self):
    """Load the DPO-format dataset and flatten it into prompt/answer rows."""
    dataset = load_dataset(self.dataset_name, split=self.split)
    system_message = {"role": "system", "content": SYSTEM_PROMPT}
    new_samples = {"prompt": [], "answer": []}
    for sample in dataset:
        # Add the chosen response with its score
        new_samples["prompt"].append([system_message] + sample["chosen"])
        new_samples["answer"].append(sample["score_chosen"])
        # Add the rejected response with its score
        new_samples["prompt"].append([system_message] + sample["rejected"])
        new_samples["answer"].append(sample["score_rejected"])
    new_dataset = Dataset.from_dict(new_samples)
    # Cast answer to string
    new_dataset = new_dataset.cast_column("answer", Value("string"))
    # Keep the formatted dataset so reset() can sample from it
    self._dataset = new_dataset
```

Above, we convert the dataset from DPO format into simple `prompt` and `answer` columns. Every DPO row yields two samples, one for the chosen response and one for the rejected response, and each is paired with its score.
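The same reshaping can be seen without downloading anything in the stdlib-only sketch below. It runs the expansion on a single made-up row that has the same keys the code above reads. The helper name `expand_dpo_rows` is hypothetical and is not part of the post's code. `str()` stands in for the `cast_column` step, so the exact string format may differ slightly from what `datasets` produces.

```python
def expand_dpo_rows(rows: list[dict], system_prompt: str) -> dict[str, list]:
    """Turn each DPO-style row into two (prompt, answer) rows:
    one for the chosen response, one for the rejected response."""
    system_message = {"role": "system", "content": system_prompt}
    expanded: dict[str, list] = {"prompt": [], "answer": []}
    for row in rows:
        for side in ("chosen", "rejected"):
            expanded["prompt"].append([system_message] + row[side])
            # Scores become strings, roughly mirroring the cast_column step
            expanded["answer"].append(str(row[f"score_{side}"]))
    return expanded


# A made-up row with the same keys as the ultrafeedback samples used above
example = [
    {
        "chosen": [
            {"role": "user", "content": "What is 2 + 2?"},
            {"role": "assistant", "content": "4"},
        ],
        "rejected": [
            {"role": "user", "content": "What is 2 + 2?"},
            {"role": "assistant", "content": "5"},
        ],
        "score_chosen": 9.0,
        "score_rejected": 2.0,
    }
]
expanded = expand_dpo_rows(example, "Grade the response.")
print(expanded["answer"])  # ['9.0', '2.0']
```

One input row becomes two training rows, so the flattened dataset is twice as long as the original split.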
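For the sake of this example, the reward is based on the absolute error between the model's predicted score and the human-assigned score. It is scaled so that a perfect prediction earns 1.0, and a small bonus is added for following the answer format.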
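The formula below is read off the `step` method that follows. Here $\hat{y}$ is the parsed prediction, $y$ is the human score, and $f$ is the value returned by `calculate_format_reward`:

$$
r_{\text{correct}} =
\begin{cases}
\max\left(0,\; 1 - \dfrac{|\hat{y} - y|}{10}\right) & \text{if an answer was parsed} \\
0 & \text{otherwise}
\end{cases}
\qquad
r = \frac{1.0 \cdot r_{\text{correct}} + 0.2 \cdot f}{1.2}
$$

As a worked example, suppose the model predicts $\hat{y} = 7$, the human score is $y = 9$, and a well-formatted answer earns $f = 1$. Then $r_{\text{correct}} = 1 - 2/10 = 0.8$ and $r = (0.8 + 0.2)/1.2 \approx 0.83$. As long as $f \le 1$, dividing by 1.2 keeps $r$ in $[0, 1]$.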
The reward calculation happens in the environment step method:
```python
# Method of UltrafeedbackEnvironment (shown on its own)
def step(self, action: UltrafeedbackAction) -> UltrafeedbackObservation:
    if self._current_sample is None:
        raise RuntimeError("Environment not initialized. Call reset() first.")
    self._state.step_count += 1

    completion = action.completion
    correct_answer_str = self._current_sample["answer"]

    # Correctness: 1.0 for an exact match, decreasing linearly with the
    # absolute error (clipped at 0); 0.0 if no answer could be parsed
    response = parse_numeric_answer(completion)
    if response is None:
        correctness_reward = 0.0
    else:
        correct_answer = float(correct_answer_str)
        difference = abs(response - correct_answer)
        correctness_reward = 1.0 - (difference / 10.0)
        correctness_reward = max(0.0, correctness_reward)

    # Small bonus for following the <answer>...</answer> format
    format_reward = calculate_format_reward(completion)
    total_reward = (correctness_reward * 1.0 + format_reward * 0.2) / 1.2

    # Single-step episode: always done after one action
    return UltrafeedbackObservation(
        prompt=self._current_sample["prompt"],
        correct_answer=correct_answer_str,
        done=True,
        reward=total_reward,
        metadata={
            "predicted_score": response,
            "correct_score": correct_answer_str,
            "correctness_reward": correctness_reward,
            "format_reward": format_reward,
        },
    )
```

Since this is a toy, single-step environment, `step` essentially simulates supervised learning. It extracts the LLM's formatted answer, compares it to the ground-truth score, and ends the episode immediately (`done=True`).
This, plus some plumbing covered in the OpenEnv docs, is all you need to define the interaction model for RFT. Next, we'll look at how to actually use it in the training loop.
TRL
In this section, we'll look at how we can easily plug our custom gym environment into
TRL's GRPOTrainer class. We'll start by instantiating our environment:
```python
from datasets import Dataset

env = UltrafeedbackEnvironment()

# Hack: `GRPOTrainer` still expects a dataset argument, but we handle the
# dataset logic in our env, so these placeholder prompts are ignored
train_dataset = Dataset.from_dict(
    {"prompt": ["Blah blah this will be ignored"] * 64}
)
```

Next, we need to define our `rollout_func`. This is what tells TRL how to interact with our environment.
GRPO scores each completion relative to the other completions in its group. For that variance reduction to work, every rollout in a group must start from the same initial prompt. This is why we call `env.reset` with the same seed for every generation in a group.
This same pattern extends easily to multi-step environments like Wordle or gridworld.
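The sketch below is a simplified version of GRPO's group-relative normalization. It shows why groups need a shared prompt. It is an illustration only: TRL's implementation has more knobs (for example, how rewards are scaled), and `group_advantages` is a hypothetical helper. Rewards are assumed to be laid out group by group, which is the order `rollout_func` produces.

```python
from statistics import mean, stdev


def group_advantages(
    rewards: list[float], num_generations: int, eps: float = 1e-4
) -> list[float]:
    """Normalize each reward against the other rewards in its group."""
    if num_generations < 2 or len(rewards) % num_generations != 0:
        raise ValueError("need >= 2 generations and whole groups of rewards")
    advantages: list[float] = []
    for start in range(0, len(rewards), num_generations):
        group = rewards[start : start + num_generations]
        mu, sigma = mean(group), stdev(group)  # sample std within the group
        advantages.extend((r - mu) / (sigma + eps) for r in group)
    return advantages


# One group of 4 completions for the same prompt
print(group_advantages([1.0, 0.5, 0.5, 0.0], num_generations=4))
# Identical rewards -> zero advantage -> no learning signal from this group
print(group_advantages([0.8, 0.8, 0.8, 0.8], num_generations=4))
```

If a group mixed different prompts, its mean would partly reflect how hard each prompt is. The advantage would then reward easy prompts rather than good completions.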
Our rollout logic is split into two functions:
- `rollout_episode`: a helper that executes a single episode of our environment
- `rollout_func`: a wrapper that executes all the episodes needed for each batch of GRPO training
This is a lot of code, so we'll step through it piece by piece below
Rollout Function Implementation
```python
import random

from transformers import AutoTokenizer, PreTrainedTokenizerBase
from trl import GRPOConfig
from trl.extras.vllm_client import VLLMClient

# model_name and env are defined earlier (e.g. inside start_grpo_trainer)
vllm_client = VLLMClient()
tokenizer = AutoTokenizer.from_pretrained(model_name)


def rollout_episode(
    training_args: GRPOConfig,
    env: UltrafeedbackEnvironment,
    tokenizer: PreTrainedTokenizerBase,
    vllm_client: VLLMClient,
    seed: int | None = None,
    max_steps: int = 10,
) -> dict[str, list]:
    # Same seed -> same initial prompt for every rollout in a GRPO group
    obs = env.reset(seed=seed)
    initial_prompt = obs.prompt

    prompt_ids: list[int] = []
    completion_ids: list[int] = []
    logprobs: list[float] = []
    raw_rewards: list[float] = []
    completion_texts: list[str] = []

    current_obs = obs
    step_count = 0
    episode_reward = 0.0

    while not current_obs.done and step_count < max_steps:
        # Generate a completion from the vLLM server
        response = vllm_client.chat(
            # For our dataset, `prompt` is a chat-format list of messages
            [current_obs.prompt],
            temperature=training_args.temperature,
            top_k=training_args.top_k or -1,
            max_tokens=training_args.max_completion_length or 16,
        )

        step_prompt_ids = response["prompt_ids"][0]
        step_completion_ids = response["completion_ids"][0]
        step_logprobs = response["logprobs"][0]

        completion_text = tokenizer.decode(
            step_completion_ids, skip_special_tokens=True
        )

        # Let the environment score the completion
        action = UltrafeedbackAction(completion=completion_text)
        current_obs = env.step(action)

        # Accumulate token IDs, logprobs, and rewards across steps
        prompt_ids.extend(step_prompt_ids)
        completion_ids.extend(step_completion_ids)
        logprobs.extend(step_logprobs)
        completion_texts.append(completion_text)

        step_reward = float(current_obs.reward or 0.0)
        raw_rewards.append(step_reward)
        episode_reward += step_reward

        step_count += 1

    return {
        "prompt_ids": prompt_ids,
        "completion_ids": completion_ids,
        "logprobs": logprobs,
        "raw_rewards": [episode_reward],
        "completion_text": completion_texts,
        "prompt": initial_prompt,
        "correct_answer": [getattr(obs, "correct_answer", "")],
    }


def rollout_func(
    prompts: list[str], args: GRPOConfig, processing_class: PreTrainedTokenizerBase
):
    # Only reads the outer vllm_client, so no `nonlocal` is needed
    assert vllm_client is not None

    episode_prompt_ids: list[list[int]] = []
    episode_completion_ids: list[list[int]] = []
    episode_logprobs: list[list[float]] = []
    correctness_rewards: list[float] = []

    # The placeholder prompts are ignored: the env supplies the real ones
    for _ in range(len(prompts)):
        # One seed per group so all num_generations rollouts share a prompt
        group_seed = random.randint(0, 2**31 - 1)
        num_gens = args.num_generations or 1
        for _ in range(num_gens):
            rollout_result = rollout_episode(
                training_args=args,
                env=env,
                tokenizer=processing_class,
                vllm_client=vllm_client,
                seed=group_seed,
                max_steps=getattr(args, "max_episode_steps", 10),
            )

            episode_prompt_ids.append(rollout_result["prompt_ids"])
            episode_completion_ids.append(rollout_result["completion_ids"])
            episode_logprobs.append(rollout_result["logprobs"])
            correctness_rewards.append(rollout_result["raw_rewards"][0])

    # Extra keys such as raw_rewards reach the reward functions as kwargs
    return {
        "prompt_ids": episode_prompt_ids,
        "completion_ids": episode_completion_ids,
        "logprobs": episode_logprobs,
        "raw_rewards": correctness_rewards,
    }
```

Okay, now that we have our rollout functions, we can hook them up to TRL! The following should look familiar if you've ever followed a TRL tutorial:
```python
from pathlib import Path

from trl import GRPOConfig, GRPOTrainer


def reward_from_env(
    prompts: list[str], completions: list[str], **kwargs
) -> list[float]:
    """Extract the rewards the environment computed during rollout."""
    # `raw_rewards` comes from the dict returned by rollout_func
    env_rewards = kwargs.get("raw_rewards", [])
    return (
        [float(reward) for reward in env_rewards]
        if env_rewards
        else [0.0] * len(completions)
    )


trainer = GRPOTrainer(
    model=model_name,
    reward_funcs=reward_from_env,
    train_dataset=train_dataset,
    rollout_func=rollout_func,
    processing_class=tokenizer,
    args=GRPOConfig(
        output_dir=str(Path(f"/models/{experiment_name}")),
        vllm_mode="server",  # generate via the separate `trl vllm-serve` process
        use_vllm=True,
        num_train_epochs=1,
        num_generations=4,
        max_completion_length=128,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=1,
        logging_steps=1,
        report_to="wandb",
        temperature=0.8,
        top_k=10,
    ),
)
trainer.train()
```

Deploying on Modal
At this point, assuming you can run vLLM locally, you could run the training code on your own machine. Unfortunately for me, installing vLLM on Apple silicon is a pain in the ass, so Modal comes to the rescue.
To run our training script on Modal, we just need to define a `modal.Image`, grab 2 GPUs (one for vLLM inference, one for training), and off we go.
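The two-GPU split comes down to process isolation: each process only sees the GPUs listed in its `CUDA_VISIBLE_DEVICES`. The sketch below is a hypothetical helper, `background_process`, which is not part of the post's code or any library. It starts a process pinned to given GPUs and always terminates it on exit, so the vLLM server is shut down even if training crashes.

```python
import os
import subprocess
import sys
from collections.abc import Iterator
from contextlib import contextmanager


@contextmanager
def background_process(cmd: list[str], visible_gpus: str) -> Iterator[subprocess.Popen]:
    """Run `cmd` with only `visible_gpus` exposed to CUDA; stop it on exit."""
    child_env = os.environ.copy()
    child_env["CUDA_VISIBLE_DEVICES"] = visible_gpus
    proc = subprocess.Popen(cmd, env=child_env)
    try:
        yield proc
    finally:
        proc.terminate()  # polite stop first (SIGTERM on POSIX)
        try:
            proc.wait(timeout=30)
        except subprocess.TimeoutExpired:
            proc.kill()  # force it if it ignores SIGTERM
            proc.wait()


# Sketch of use inside a launcher like start_grpo_trainer:
# with background_process(["trl", "vllm-serve", "--model", model_name], "0"):
#     os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # set before CUDA is initialized
#     trainer.train()
```

The parent's own `CUDA_VISIBLE_DEVICES` must be set before anything in that process initializes CUDA. Changing it afterwards has no effect on an already-initialized context.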
Modal supports native Python breakpoints! You just need to run your Modal function in interactive mode
(`modal run -i app.py::function`). This is a lifesaver if you need to debug your code or just want to understand
the flow of data in this example.
```python
from datetime import datetime
from pathlib import Path

import modal

app = modal.App(name="trl-openenv")

PYPROJECT_DEPENDENCIES = [
    "datasets>=4.4.1",
    "transformers>=4.57.1",
    "trl[vllm]>=0.25.1",
    "wandb>=0.22.3",
    "weave",
]

image = (
    modal.Image.debian_slim(python_version="3.12")
    .uv_pip_install(*PYPROJECT_DEPENDENCIES)
    # I have OpenEnv as a git subtree so Cursor can easily reference it; you could also install from GitHub
    .add_local_dir("OpenEnv", remote_path="/repos/OpenEnv", copy=True)
    .uv_pip_install("/repos/OpenEnv")
    # Add the ultrafeedback environment
    .add_local_python_source("ultrafeedback", copy=True)
    .add_local_dir("src", remote_path="/root/src")
)

# Placeholders: MODELS_DIR matches the GRPOConfig output_dir root;
# the volume name is illustrative
MODELS_DIR = Path("/models")
checkpoints_volume = modal.Volume.from_name(
    "trl-openenv-checkpoints", create_if_missing=True
)


def _train_grpo(model_name: str, experiment_name: str):
    start_grpo_trainer(model_name, experiment_name)


@app.function(
    image=image,
    gpu="A10G:2",  # one GPU for vLLM, one for training
    timeout=60 * 60 * 24,  # 24 hours
    secrets=[modal.Secret.from_name("wandb-secret")],
    volumes={str(MODELS_DIR): checkpoints_volume},
)
def train_grpo(model_name: str, experiment_name: str):
    experiment_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    _train_grpo(model_name, f"{experiment_name}_{experiment_timestamp}")
```
In `start_grpo_trainer`, we just need to launch vLLM as a subprocess on one GPU, then run the training loop we defined above on the other:
```python
import os
import subprocess


def start_grpo_trainer(model_name: str, experiment_name: str):
    env_copy = os.environ.copy()
    env_copy["CUDA_VISIBLE_DEVICES"] = "0"  # serve vLLM on GPU 0

    # Start `trl vllm-serve` in the background
    vllm_process = subprocess.Popen(
        [
            "trl",
            "vllm-serve",
            "--model",
            model_name,
            "--gpu-memory-utilization",
            "0.8",
        ],
        env=env_copy,
    )

    # Make sure to run training on our other GPU; this must be set
    # before anything in this process initializes CUDA
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"
    trainer = ...
```

Training a model
Kicking off a training run is easy: `modal run app.py::train_grpo --model-name Qwen/Qwen2.5-0.5B-Instruct --experiment-name <name>`
We can then monitor the run on wandb and confirm that the reward goes up.
Conclusion
In this blog post, we glued together TRL, OpenEnv, and Modal to run gym style RFT on serverless compute. The full code is available here.
This project emerged from a larger experiment I'm running in this repo, so please ignore some of the slop in the other files for now :)
The main code pointers for this post are:
I hope you found this post helpful. Please feel free to reach out on X or in the repo if you have any questions or feedback.