# Advanced Reinforcement Learning Framework

A production-grade reinforcement learning system implementing PPO, SAC, and multi-agent training, with distributed computing support.
## System Architecture
A scalable reinforcement learning framework that supports multiple algorithms, distributed training, and complex environment simulation.
## Core Components

### 1. Policy Network Architecture
```python
from typing import Any, Dict, List, Tuple

import torch
import torch.nn as nn
from torch.distributions import Distribution, Normal


class PolicyNetwork(nn.Module):
    """Actor-critic network with a shared state encoder."""

    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        self.state_encoder = StateEncoder(
            input_dim=config['state_dim'],
            hidden_dims=config['encoder_dims']
        )
        self.policy_head = GaussianPolicyHead(
            input_dim=config['encoder_dims'][-1],
            action_dim=config['action_dim'],
            log_std_bounds=(-20, 2)
        )
        self.value_head = ValueHead(
            input_dim=config['encoder_dims'][-1]
        )

    def forward(
        self,
        state: torch.Tensor
    ) -> Tuple[Distribution, torch.Tensor]:
        features = self.state_encoder(state)
        action_dist = self.policy_head(features)
        value = self.value_head(features)
        return action_dist, value


class GaussianPolicyHead(nn.Module):
    """Diagonal Gaussian policy head with a clamped, state-independent log-std."""

    def __init__(
        self,
        input_dim: int,
        action_dim: int,
        log_std_bounds: Tuple[float, float]
    ):
        super().__init__()
        self.mean = nn.Linear(input_dim, action_dim)
        self.log_std = nn.Parameter(torch.zeros(action_dim))
        self.log_std_bounds = log_std_bounds

    def forward(self, x: torch.Tensor) -> Distribution:
        mean = self.mean(x)
        log_std = torch.clamp(self.log_std, *self.log_std_bounds)
        return Normal(mean, log_std.exp())
```
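
As a quick sanity check, the policy can be exercised on a batch of random states. This is a minimal sketch using the same dimensions as the usage example at the end of this document; it assumes `StateEncoder` and `ValueHead` are the framework's MLP encoder and scalar value modules:

```python
# Hypothetical smoke test for the policy network
policy = PolicyNetwork({
    'state_dim': 64,
    'action_dim': 6,
    'encoder_dims': [256, 256]
})

states = torch.randn(32, 64)                        # batch of 32 states
action_dist, values = policy(states)

actions = action_dist.sample()                      # (32, 6)
log_probs = action_dist.log_prob(actions).sum(-1)   # one log-prob per state
print(actions.shape, values.shape, log_probs.shape)
```

Summing the per-dimension log-probabilities gives a single log-probability per transition, which is the convention the PPO ratio in the next section relies on.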
### 2. PPO Implementation
```python
import torch.nn.functional as F


class PPOTrainer:
    def __init__(self, config: Dict[str, Any]):
        self.policy = PolicyNetwork(config['policy_config'])
        self.optimizer = torch.optim.Adam(
            self.policy.parameters(),
            lr=config['learning_rate']
        )
        self.clip_range = config['clip_range']
        self.value_coef = config['value_coef']
        self.entropy_coef = config['entropy_coef']

    def compute_loss(
        self,
        batch: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        # Get current policy distribution and value estimates
        action_dist, values = self.policy(batch['states'])

        # Policy ratio; log-probs are summed over action dimensions so the
        # ratio is per-transition (old_log_probs are summed the same way
        # during rollout collection)
        log_probs = action_dist.log_prob(batch['actions']).sum(-1)
        ratio = torch.exp(log_probs - batch['old_log_probs'])

        # Clipped surrogate objective
        advantages = batch['advantages']
        policy_loss1 = advantages * ratio
        policy_loss2 = advantages * torch.clamp(
            ratio,
            1 - self.clip_range,
            1 + self.clip_range
        )
        policy_loss = -torch.min(policy_loss1, policy_loss2).mean()

        # Value function loss against empirical returns
        value_loss = F.mse_loss(values.squeeze(-1), batch['returns'])

        # Entropy bonus (negated so that minimizing increases entropy)
        entropy_loss = -action_dist.entropy().sum(-1).mean()

        return {
            'policy_loss': policy_loss,
            'value_loss': value_loss * self.value_coef,
            'entropy_loss': entropy_loss * self.entropy_coef
        }
```
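
`compute_loss` assumes that `batch['advantages']` and `batch['returns']` have already been computed from the collected rollouts. A minimal sketch of generalized advantage estimation (GAE) that could produce them is shown below; `gamma` and `gae_lambda` are illustrative defaults rather than values taken from the original config:

```python
def compute_gae(
    rewards: torch.Tensor,      # (T,)
    values: torch.Tensor,       # (T + 1,) state values plus bootstrap value
    dones: torch.Tensor,        # (T,) 1.0 where the episode ended
    gamma: float = 0.99,
    gae_lambda: float = 0.95
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generalized advantage estimation over a single rollout."""
    advantages = torch.zeros_like(rewards)
    last_gae = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - dones[t]
        delta = rewards[t] + gamma * values[t + 1] * not_done - values[t]
        last_gae = delta + gamma * gae_lambda * not_done * last_gae
        advantages[t] = last_gae
    returns = advantages + values[:-1]
    return advantages, returns
```

Advantages are usually normalized per batch (subtract the mean, divide by the standard deviation) before they enter the clipped objective.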
### 3. Distributed Training
```python
import asyncio

import ray


class DistributedTrainer:
    def __init__(self, config: Dict[str, Any]):
        train_cfg = config['training_config']
        self.num_workers = train_cfg['num_workers']
        self.rollout_length = train_cfg['rollout_length']
        self.workers = [
            RolloutWorker.remote(config)
            for _ in range(self.num_workers)
        ]
        self.learner = PPOLearner.remote(config)

    async def train(
        self,
        num_iterations: int
    ) -> List[Dict[str, float]]:
        metrics = []
        for _ in range(num_iterations):
            # Collect rollouts in parallel; Ray ObjectRefs are awaitable,
            # so gather them instead of blocking on ray.get()
            rollout_ids = [
                worker.collect_rollout.remote()
                for worker in self.workers
            ]
            rollouts = await asyncio.gather(*rollout_ids)

            # Update the policy on the learner
            batch = self._prepare_batch(rollouts)
            update_metrics = await self.learner.update.remote(batch)

            # Broadcast the updated policy weights to all workers
            policy_state = await self.learner.get_policy_state.remote()
            sync_ops = [
                worker.sync_policy.remote(policy_state)
                for worker in self.workers
            ]
            await asyncio.gather(*sync_ops)

            metrics.append(update_metrics)
        return metrics
```
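
`RolloutWorker` and `PPOLearner` are referenced here but defined elsewhere in the framework. As an orientation aid, a hypothetical sketch of the worker interface that `DistributedTrainer` relies on might look like the following (the body of `collect_rollout` is elided; `VectorizedEnv` is the wrapper from the next section):

```python
@ray.remote
class RolloutWorker:
    """Hypothetical sketch of the actor interface assumed by DistributedTrainer."""

    def __init__(self, config: Dict[str, Any]):
        self.env = VectorizedEnv(config['env_config'])
        self.policy = PolicyNetwork(config['policy_config'])
        self.rollout_length = config['training_config']['rollout_length']

    def collect_rollout(self) -> Dict[str, np.ndarray]:
        # Step the vectorized env for rollout_length steps with the current
        # policy and return states, actions, rewards, dones, old log-probs
        # and value estimates for the learner.
        ...

    def sync_policy(self, policy_state: Dict[str, torch.Tensor]) -> None:
        # Load the latest learner weights into the local policy copy
        self.policy.load_state_dict(policy_state)
```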
### 4. Environment Wrappers
```python
import gym
import numpy as np


class VectorizedEnv:
    def __init__(self, config: Dict[str, Any]):
        self.envs = [
            gym.make(config['env_id'])
            for _ in range(config['num_envs'])
        ]
        self.observation_space = self.envs[0].observation_space
        self.action_space = self.envs[0].action_space

    def step(
        self,
        actions: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[Dict]]:
        results = [
            env.step(action)
            for env, action in zip(self.envs, actions)
        ]
        # zip() yields tuples; convert states to a list so that finished
        # environments can be replaced with fresh initial observations
        states, rewards, dones, infos = zip(*results)
        states = list(states)

        # Reset environments that are done
        for i, done in enumerate(dones):
            if done:
                states[i] = self.envs[i].reset()

        return (
            np.stack(states),
            np.stack(rewards),
            np.stack(dones),
            list(infos)
        )
```
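
The wrapper exposes `step` but no batched reset, and a rollout loop also needs the initial observations. Below is a hypothetical helper (in the framework this would more naturally be a `reset()` method on `VectorizedEnv`) plus a short usage sketch, assuming the classic 4-tuple gym step API used above:

```python
def reset_all(vec_env: VectorizedEnv) -> np.ndarray:
    # Reset every sub-environment and stack the initial observations
    return np.stack([env.reset() for env in vec_env.envs])


# Usage sketch
vec_env = VectorizedEnv({'env_id': 'HalfCheetah-v2', 'num_envs': 4})
states = reset_all(vec_env)                                   # (4, obs_dim)
actions = np.stack([vec_env.action_space.sample() for _ in range(4)])
next_states, rewards, dones, infos = vec_env.step(actions)
```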
### 5. Experience Replay
```python
class PrioritizedReplayBuffer:
    """Proportional prioritized experience replay buffer."""

    def __init__(self, capacity: int, alpha: float = 0.6):
        self.capacity = capacity
        self.alpha = alpha
        self.buffer = []
        self.priorities = np.zeros(capacity, dtype=np.float32)
        self.position = 0

    def add(
        self,
        state: np.ndarray,
        action: np.ndarray,
        reward: float,
        next_state: np.ndarray,
        done: bool
    ) -> None:
        # New transitions get the current maximum priority so they are
        # sampled at least once before their TD error is known
        max_priority = self.priorities.max() if self.buffer else 1.0
        if len(self.buffer) < self.capacity:
            self.buffer.append(
                (state, action, reward, next_state, done)
            )
        else:
            self.buffer[self.position] = (
                state, action, reward, next_state, done
            )
        self.priorities[self.position] = max_priority
        self.position = (self.position + 1) % self.capacity

    def sample(
        self,
        batch_size: int,
        beta: float = 0.4
    ) -> Tuple[Dict[str, np.ndarray], np.ndarray, np.ndarray]:
        if len(self.buffer) == self.capacity:
            priorities = self.priorities
        else:
            priorities = self.priorities[:len(self.buffer)]

        # Sampling probabilities proportional to priority ** alpha
        probs = priorities ** self.alpha
        probs /= probs.sum()
        indices = np.random.choice(
            len(self.buffer),
            batch_size,
            p=probs
        )
        samples = [self.buffer[idx] for idx in indices]

        # Importance-sampling weights, normalized by the maximum weight
        weights = (len(self.buffer) * probs[indices]) ** (-beta)
        weights /= weights.max()

        batch = {
            'states': np.stack([s[0] for s in samples]),
            'actions': np.stack([s[1] for s in samples]),
            'rewards': np.stack([s[2] for s in samples]),
            'next_states': np.stack([s[3] for s in samples]),
            'dones': np.stack([s[4] for s in samples])
        }
        return batch, indices, weights
```
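
`sample()` returns the sampled indices and importance weights precisely so that priorities can be refreshed once new TD errors are available. The buffer above does not include that step, so the following is a hypothetical companion (a sketch, not part of the original class), followed by a usage sketch:

```python
def update_priorities(
    buffer: PrioritizedReplayBuffer,
    indices: np.ndarray,
    td_errors: np.ndarray,
    eps: float = 1e-6
) -> None:
    # Hypothetical helper: in practice this would be a method on the buffer.
    # A small epsilon keeps every transition sampleable.
    buffer.priorities[indices] = np.abs(td_errors) + eps


# Usage sketch
replay = PrioritizedReplayBuffer(capacity=100_000)
# ... fill the buffer with replay.add(state, action, reward, next_state, done) ...
batch, indices, weights = replay.sample(batch_size=256, beta=0.4)
# ... compute per-sample TD errors from the learner, then refresh priorities:
td_errors = np.zeros(256)  # placeholder for real TD errors
update_priorities(replay, indices, td_errors)
```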
## Usage Example
```python
# Initialize the training system
config = {
    'policy_config': {
        'state_dim': 64,
        'action_dim': 6,
        'encoder_dims': [256, 256]
    },
    'training_config': {
        'num_workers': 8,
        'rollout_length': 2048,
        'learning_rate': 3e-4,
        'clip_range': 0.2,
        'value_coef': 0.5,
        'entropy_coef': 0.01
    },
    'env_config': {
        'env_id': 'HalfCheetah-v2',
        'num_envs': 16
    }
}

trainer = DistributedTrainer(config)

# Train the agent (awaited from inside a running event loop;
# see the entry-point sketch below)
metrics = await trainer.train(num_iterations=1000)

# Evaluate and save results
evaluator = PolicyEvaluator(config)
results = await evaluator.evaluate(trainer.learner)
```
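
`train` and `evaluate` are coroutines, and `DistributedTrainer` creates Ray actors, so the snippet above assumes a running event loop and an initialized Ray runtime. A minimal entry point, sketched under those assumptions (`PolicyEvaluator` is referenced above but not defined in this document), might look like:

```python
import asyncio

import ray


async def main() -> None:
    # Construct the trainer only after ray.init(), since it spawns actors
    trainer = DistributedTrainer(config)
    metrics = await trainer.train(num_iterations=1000)

    evaluator = PolicyEvaluator(config)
    results = await evaluator.evaluate(trainer.learner)
    print(results)


if __name__ == '__main__':
    ray.init()          # start the local Ray runtime
    asyncio.run(main())
    ray.shutdown()
```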