Inspired by the work at https://arxiv.org/abs/2510.01527, we have started experimenting with the idea of round-trip reinforcement learning (RTRL) for generating valid CIF files. The core principle is simple: a model trained to convert between two representations (e.g., CIF to JSON for crystal structures) should preserve enough information that the conversion can be reversed accurately.
'RTRL' is just slightly modified GRPO. The policy model performs a single conversion; in this case, given a CIF file, it generates a JSON representation. We then evaluate quality by asking: "If we tried to convert this JSON back to the original CIF, how likely would we succeed?" A judge model computes the log probability of reconstructing the original CIF from the generated JSON. Higher probability indicates better information preservation, which means the conversion maintains fidelity.
Critically, the policy model only learns one direction (e.g., CIF → JSON), while the judge model evaluates the reverse direction (JSON → CIF) through probability computation, not actual generation. This creates a self-supervised learning signal without requiring human annotations or explicit round-trip generation (a clarification that will matter later).
The system consists of three main components:
Policy Model: A trainable Qwen 2.5-3B model with LoRA adapters that generates conversions
Judge Model: A frozen Qwen 2.5-3B-AWQ model served via vLLM that scores round-trip consistency
Reference Model: A frozen model used for KL divergence regularization
We deploy the entire system on Modal, which provides containerized deployment with automatic GPU provisioning. The system uses three H100 GPUs (80 GB each): one dedicated to policy training, and the other two serving the judge and reference models respectively.
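As a rough sketch of the layout on Modal (app name, image contents, and function bodies here are illustrative, not our exact deployment):

```python
import modal

# Hypothetical app layout: one container per role, each on its own H100.
app = modal.App("rtrl-cif")
image = modal.Image.debian_slim().pip_install("torch", "transformers", "peft", "vllm")

@app.function(gpu="H100", image=image, timeout=24 * 60 * 60)
def train_policy():
    ...  # LoRA policy training loop lives here

@app.function(gpu="H100", image=image)
def serve_judge():
    ...  # vLLM server for the frozen judge model

@app.function(gpu="H100", image=image)
def serve_reference():
    ...  # vLLM server for the frozen reference model
```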
For model serving, we use vLLM to host the judge and reference models. vLLM provides OpenAI-compatible API endpoints, making integration straightforward, while delivering efficient batched inference through PagedAttention (credit to the RTRL paper authors for this, it made our life much easier). FP8 KV cache quantization was required in this build to let us handle long sequences (up to 10K tokens) within our memory budget; full cache precision can be maintained as long as token sequences for a given pass stay under roughly 4,000 tokens.
The training loop follows this sequence:
Sample batch of crystal structures from dataset (e.g., CIF files)
Policy generates conversions (CIF → JSON) using temperature sampling
Judge computes log probability of reverse conversion (JSON → CIF) without actually generating
Higher reverse probability = better information preservation = higher reward
Compute advantages and update policy
Repeat for next batch
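In rough pseudocode-style Python, with `generate_json`, `judge_logprob`, `normalize_within_groups`, and `grpo_update` as placeholder names (not our actual function names), the loop looks something like this:

```python
def rtrl_training_loop(policy, judge, dataset, samples_per_input=2):
    for batch in dataset:                                   # 1. sample CIF files
        groups = []
        for cif in batch:
            samples = []
            for _ in range(samples_per_input):
                json_repr = generate_json(policy, cif, temperature=0.9)       # 2. CIF -> JSON
                reward = judge_logprob(judge, prompt=json_repr, target=cif)   # 3-4. score reverse direction
                samples.append((json_repr, reward))
            groups.append((cif, samples))
        advantages = normalize_within_groups(groups)        # 5. group-relative advantages
        grpo_update(policy, groups, advantages)             #    policy update (+ KL to reference)
```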
The policy model is built on Qwen 2.5-3B, a 3 billion parameter language model with a 32,768 token context window. We configure the tokenizer with left-padding to properly handle batched generation in the decoder-only architecture.
Rather than fine-tuning the entire model, we're using Low-Rank Adaptation (LoRA) to make training tractable. Our LoRA configuration uses rank 32 with alpha 32, applying adapters to all attention and MLP layers. We set dropout to 0.05 to provide mild regularization without overly constraining the model's expressiveness (thanks Claude).
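For illustration, here is roughly how that configuration maps onto `peft` (assuming the Instruct variant of the 3B model; target module names follow standard Qwen2.5 layer naming):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

model_name = "Qwen/Qwen2.5-3B-Instruct"  # assumption: Instruct variant of the 3B model
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")  # left-padding for batched decoding
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="cuda")

lora_config = LoraConfig(
    r=32,                      # rank 32
    lora_alpha=32,             # alpha 32
    lora_dropout=0.05,         # mild regularization
    target_modules=[           # all attention and MLP projections
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```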
Several optimization techniques keep memory usage under control. Gradient checkpointing trades computation for memory by recomputing activations during the backward pass, enabling larger batch sizes. We enable TF32 precision for matrix operations on the H100s, providing a 3-4x speedup over FP32 with negligible accuracy impact. Finally, micro-batching allows us to process logically large batches while staying within memory constraints.
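In PyTorch terms, these settings boil down to something like the following (micro-batching is covered in the training section below):

```python
import torch

# TF32 matmuls on the H100s: large speedup over FP32 with negligible accuracy impact
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Recompute activations during the backward pass to trade compute for memory
model.gradient_checkpointing_enable()
model.config.use_cache = False  # the KV cache is incompatible with checkpointing during training
```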
Model: Qwen/Qwen2.5-3B-Instruct-AWQ
AWQ 4-bit weight quantization (4x memory reduction)
FP8 KV cache quantization
Served via vLLM on dedicated GPUs
vLLM Configuration (key parameters):
Context length: 10,240 tokens (balances memory and data requirements)
GPU memory utilization: 55% (prevents OOM while maximizing throughput)
Eager mode: Ensures consistent memory usage
Single sequence processing: Reduces memory fragmentation
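For reference, here is a sketch of the same settings expressed through vLLM's Python engine API; in practice we run the judge behind vLLM's OpenAI-compatible server, where the equivalent CLI flags are noted in the comments:

```python
from vllm import LLM

judge = LLM(
    model="Qwen/Qwen2.5-3B-Instruct-AWQ",
    quantization="awq",            # --quantization awq (4-bit weights)
    kv_cache_dtype="fp8",          # --kv-cache-dtype fp8
    max_model_len=10240,           # --max-model-len 10240
    gpu_memory_utilization=0.55,   # --gpu-memory-utilization 0.55
    enforce_eager=True,            # --enforce-eager (consistent memory usage)
    max_num_seqs=1,                # --max-num-seqs 1 (single-sequence processing)
)
```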
Understanding how the judge computes rewards is essential to grasping RTRL. The judge never actually generates the reverse conversion. Instead, it computes the probability of the target through teacher forcing.
Here's a concrete example. Suppose the policy converts the CIF for an Fe2O3 structure into a JSON representation containing, say, "composition": "Fe2O3" and "lattice": {"a": 5.035, ...}.
The judge evaluates: "What's the probability of generating the original CIF from this JSON?" It doesn't generate anything; instead, it scores the target sequence token by token:
Position 1: Given the JSON prompt, what's the probability the first token is "data"?
Judge outputs probability distribution over entire vocabulary (50,000+ tokens)
Judge assigns P("data") = 0.87 ← High! The JSON makes this predictable
Record: log(0.87) = -0.14
Position 2: Given JSON + "data", what's the probability the next token is "_"?
Judge assigns P("_") = 0.65 ← High! Standard CIF format
Record: log(0.65) = -0.43
Position 3: Given JSON + "data_", what's the probability the next token is "Fe"?
Judge assigns P("Fe") = 0.43 ← Medium. JSON specifies Fe2O3
Record: log(0.43) = -0.84
This continues for every token in the target CIF. The final reward is the average log probability across all tokens. We use log probabilities because they're numerically stable: log(a × b) = log(a) + log(b), so multiplying many small probabilities becomes adding their logs.
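The scoring step itself is just a single teacher-forced forward pass. Here is a sketch using Hugging Face transformers for clarity (in our setup the judge is queried through its vLLM endpoint, and the prompt template shown is illustrative):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def roundtrip_logprob(judge, tokenizer, json_text: str, target_cif: str) -> float:
    """Average per-token log P(target CIF | generated JSON) under the judge, via teacher forcing."""
    prompt = f"Convert this JSON back to CIF:\n{json_text}\nCIF:\n"   # illustrative prompt template
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(judge.device)
    target_ids = tokenizer(target_cif, return_tensors="pt").input_ids.to(judge.device)

    input_ids = torch.cat([prompt_ids, target_ids], dim=1)
    logits = judge(input_ids).logits                        # one forward pass, no generation

    # Logits at position t predict token t+1, so shift by one to align with the target tokens.
    target_logits = logits[:, prompt_ids.shape[1] - 1 : -1, :]
    log_probs = F.log_softmax(target_logits, dim=-1)
    token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)

    return token_log_probs.mean().item()                    # reward = mean log-probability
```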
Why this works for learning:
When the policy generates good JSON that preserves information, the judge can confidently predict each CIF token. For example, if the JSON contains "lattice": {"a": 5.035}, the judge can assign high probability to the tokens "5", ".", "0", "3", "5" when reconstructing _cell_length_a 5.035.
Conversely, if the policy generates incomplete JSON like {"composition": "Fe2O3"} without lattice parameters, the judge must guess lattice values. These guesses have low probability, resulting in low reward. The policy learns to include this information to maximize reward.
Our final training reward of -0.65 means the judge assigned approximately 52% probability to each correct token on average (since exp(-0.65) ≈ 0.52). For comparison, random guessing would yield probabilities around 0.002% (1/50,000), corresponding to a log probability of about -10.8. The policy learned to generate outputs whose reconstructions are, per token, roughly 25,000x more predictable than random.
GRPO is a simplified variant of PPO that:
Groups samples by input
Normalizes rewards within each group
Computes advantages relative to group mean
Updates policy to maximize expected advantage
The loss function is:

$$\mathcal{L}_{\text{GRPO}}(\theta) = -\,\mathbb{E}\!\left[\frac{\pi_\theta(o_i \mid q)}{\pi_{\theta_{\text{old}}}(o_i \mid q)}\,\hat{A}_i\right], \qquad \hat{A}_i = \frac{r_i - \operatorname{mean}(r_1,\dots,r_G)}{\operatorname{std}(r_1,\dots,r_G)}$$

With optional KL penalty:

$$\mathcal{L}(\theta) = \mathcal{L}_{\text{GRPO}}(\theta) + \beta\, D_{\mathrm{KL}}\!\left(\pi_\theta \,\|\, \pi_{\text{ref}}\right)$$

where $G$ is the group size, $r_i$ is the judge reward for sample $i$, and $\beta$ is the KL coefficient.
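A minimal sketch of the group-relative advantage and loss computation (tensor shapes and the KL estimate are assumptions for illustration):

```python
import torch

def group_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """rewards: (num_groups, samples_per_group) mean log-probs from the judge."""
    mean = rewards.mean(dim=1, keepdim=True)
    std = rewards.std(dim=1, keepdim=True)
    return (rewards - mean) / (std + eps)        # advantage relative to the group

def grpo_loss(policy_logprobs, old_logprobs, advantages, kl_to_ref, beta=0.01):
    """Sequence log-probs under the current and old policy, plus a per-sample KL estimate."""
    ratio = torch.exp(policy_logprobs - old_logprobs)
    pg_loss = -(ratio * advantages).mean()       # maximize expected advantage
    return pg_loss + beta * kl_to_ref.mean()     # small KL penalty toward the reference model
```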
Temperature = 0.9: encourages exploration while maintaining reasonable outputs
Top-k filtering = 20
Nucleus sampling (top-p) = 0.9
Each input generates 2 samples, and we group 4 inputs together for batch processing, allowing within-group advantage normalization. Generation is capped at 768 tokens, which proved sufficient for most crystal structures in our dataset.
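With Hugging Face generate, those sampling settings correspond roughly to the following (prompt construction omitted):

```python
# `prompts` is a list of CIF -> JSON conversion prompts (construction omitted).
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
outputs = model.generate(
    **inputs,
    do_sample=True,
    temperature=0.9,              # encourage exploration
    top_k=20,
    top_p=0.9,
    max_new_tokens=768,           # sufficient for most structures in our dataset
    num_return_sequences=2,       # two samples per input
    pad_token_id=tokenizer.pad_token_id,
)
```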
For training, we adopt a constant learning rate of 5e-5 without warmup or decay (more on this in a second). The effective batch size is 64 samples, achieved through 4 groups with 8 gradient accumulation steps. This large effective batch provides stable gradient estimates crucial for RL training while keeping memory usage manageable. We apply a small KL penalty (coefficient 0.01) to prevent the policy from diverging too far from the reference, and clip gradients at norm 0.5 to handle occasional large updates.
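Collected in one place, the training configuration looks roughly like this (field names are illustrative):

```python
from dataclasses import dataclass

@dataclass
class TrainConfig:
    learning_rate: float = 5e-5      # constant, no warmup or decay
    num_groups: int = 4              # inputs grouped per batch
    samples_per_input: int = 2
    grad_accum_steps: int = 8        # 4 inputs x 2 samples x 8 steps = 64 effective samples
    kl_coef: float = 0.01            # KL penalty toward the reference model
    max_grad_norm: float = 0.5       # gradient clipping
```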
Memory management requires several coordinated techniques. The policy processes logprobs in micro-batches of 2 to prevent OOM during the backward pass. Gradient checkpointing is enabled throughout, and we use TF32 precision for computational efficiency. Finally, we configure the CUDA memory allocator with max_split_size_mb=64 to reduce memory fragmentation during long training runs.
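The allocator setting is an environment variable, and the micro-batching is a simple chunked forward/backward loop; a sketch, with `compute_loss` as a placeholder:

```python
import os
import torch

# Reduce fragmentation over long runs; must be set before the first CUDA allocation.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"

def backward_in_microbatches(policy, batch, compute_loss, micro_batch_size=2):
    """Run forward + backward a couple of samples at a time so activation memory stays bounded."""
    n = batch["input_ids"].shape[0]
    total_loss = 0.0
    for start in range(0, n, micro_batch_size):
        sl = slice(start, start + micro_batch_size)
        micro = {k: v[sl] for k, v in batch.items()}
        # Scale so the accumulated gradients match the mean loss over the full batch.
        loss = compute_loss(policy, micro) * (micro["input_ids"].shape[0] / n)
        loss.backward()              # frees this micro-batch's graph before the next forward
        total_loss += loss.item()
    return total_loss
```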
After extensive experimentation, we adopted a constant learning rate of 5e-5 throughout training. This was mainly for testing, so we could skip the warmup and see loss/reward trends more quickly.
Traditional cosine schedules with warmup typically spend 3-5% of training steps gradually ramping up the learning rate. In our initial experiments we had only about 170 steps of training data, so this would have meant wasting 80+ steps on slow warmup before reaching meaningful performance gains.

Training converged after 165 steps, at which point the dataset had been fully consumed. The mean reward improved from -1.10 at initialization to -0.65 at the end of training, representing a 0.45 point improvement or roughly 41% better performance. The final loss reached -0.003, indicating near-optimal policy optimization with respect to the advantage signal.
To interpret these reward values, recall that the reward is the average per-token log probability of the reconstruction, reward = log P(CIF | JSON) averaged over target tokens, so it is always at most zero and higher (less negative) values indicate better reconstruction quality. The initial reward of -1.10 corresponds to approximately 33% per-token reconstruction probability, while the final reward of -0.65 corresponds to 52%. This represents a substantial improvement in the model's ability to perform round-trip conversions accurately.
The training dynamics reveal several encouraging patterns. Improvement occurred consistently throughout the training run with no extended plateaus or regressions. Most tellingly, the best reward was achieved at the final training step, suggesting the model had not yet converged and could benefit from additional data or training time.
Variance in the reward signal decreased as training progressed, indicating the policy was becoming more consistent in its predictions. This is particularly important in RL settings where high variance can destabilize learning. Despite processing each sample only once, we observed no signs of overfitting. The steady improvement through the final steps suggests the model was learning generalizable patterns rather than memorizing specific training examples.
Overall performance was promising; of course, our end goal is to have a model capable of generating diverse yet valid structures. The next step is to flip this process on its head and move towards training a policy model to go from semantic descriptions → CIF. This means building a description → CIF dataset and significantly extending the data volume and training time. More to come.
Round-Trip Reinforcement Learning Experiments explores a simple idea: train a model to convert crystallography data from CIF to JSON, then judge how well that JSON could reproduce the original CIF. The policy model, based on a 3B language model with LoRA adapters, performs the forward conversion (CIF → JSON). A separate judge model, kept frozen, evaluates how likely it is to recover the exact CIF from that JSON by computing a reverse-probability score token by token, without actually generating the CIF. This score provides the reward signal for training the policy. The setup has three parts: the policy (the converter), the judge (which scores round-trips), and a reference model for regularization. Training runs on Modal with three GPUs, using vLLM for judge serving and a careful memory plan. The goal is to create a reliable, reversible representation and to extend the approach to descriptions that generate CIF files.