
Fix: revised docs
jbarnes850 committed Nov 18, 2024
1 parent 31b8bcf commit df5a259
Showing 2 changed files with 184 additions and 121 deletions.
26 changes: 15 additions & 11 deletions README.md
@@ -130,8 +130,8 @@ tail -f logs/training.log
## Documentation

- [MLX Distributed Documentation](https://ml-explore.github.io/mlx/build/html/usage/distributed.html)
- [Performance Tuning Guide](docs/performance_tuning.md)
- [Setup Guide](docs/setup_guide.md)
- [API Reference](docs/api.md)
- [Best Practices](docs/best_practices.md)

@@ -153,9 +153,10 @@ Our distributed training implementation follows MLX's recommended practices:

3. **Performance Optimization**:
- Mixed precision training
- Gradient compression for efficient communication
- Multiple TCP links for improved bandwidth
- Sliding window attention for long sequences
- Separate compute/memory streams
- Flash Attention implementation
- Grouped Query Attention (GQA)
- Optimized memory layout

4. **Monitoring and Recovery**:
- Real-time performance dashboard
@@ -190,7 +191,6 @@ For more details on MLX's distributed capabilities, see:
- Adjust number of worker processes

4. **Installation Issues**
- Update Xcode Command Line Tools
- Verify Python version compatibility
- Check MLX installation
- Review system requirements
@@ -201,6 +201,15 @@ For more detailed troubleshooting:
- Review [Performance Tuning Guide](docs/performance_tuning.md)
- Join our [Discord Community](https://discord.gg/mlx-distributed)

## Performance Tuning

For detailed information about our hardware configuration, training process, and performance optimizations, please see our [Performance Tuning Guide](docs/performance_tuning.md). This guide includes:
- Current hardware specifications and configurations
- Training time estimates and comparisons
- Detailed performance optimization strategies
- Memory management techniques
- Monitoring and stability measures

## Contributing

1. Fork the repository
@@ -211,9 +220,4 @@ For more detailed troubleshooting:

## License

MIT License

## Acknowledgments

- MLX Team at Apple
- MLX Community Contributors
279 changes: 169 additions & 110 deletions docs/performance_tuning.md
@@ -1,141 +1,200 @@
# Performance Tuning Guide

## Memory Optimization

### 1. Gradient Checkpointing

## Technical Specifications

### Current Hardware Configuration
- **Primary Node**: Mac Studio M2 Ultra
- 24-core CPU
- 160GB unified memory limit
- Batch size: 32
- **Secondary Node**: MacBook M3 Max
- 16-core CPU
- 96GB unified memory limit
- Batch size: 16
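
For reference, the per-node settings above map to a small configuration structure; a minimal sketch (field and node names here are illustrative, not the project's actual config schema):

```python
from dataclasses import dataclass

@dataclass
class NodeConfig:
    name: str
    memory_limit_gb: int  # unified memory ceiling reserved for training
    batch_size: int       # per-node micro-batch size

# Values taken from the hardware table above
NODES = [
    NodeConfig("mac-studio-m2-ultra", memory_limit_gb=160, batch_size=32),
    NodeConfig("macbook-m3-max", memory_limit_gb=96, batch_size=16),
]
```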

### Training Process (In Progress)
- **Estimated Duration**:
- Best case: 3-4 weeks
- Realistic case: 4-6 weeks
- Assumes 24/7 operation
- For comparison: Similar models (1.3B parameters) take ~6 days on 8x A100 GPUs

### Performance Optimizations
- **Memory Management**:
- Dynamic batch size adjustment
- Gradient accumulation (16 steps)
- Memory defragmentation every 100 batches
- Gradient checkpointing on alternate layers

- **Distributed Training**:
- 4 TCP links for network communication
- Asynchronous data prefetching
- Optimized gradient synchronization
- Weight broadcasting optimization

- **Computation**:
- Mixed precision training
- Separate compute/memory streams
- Flash Attention implementation
- Grouped Query Attention (GQA)
- Optimized memory layout
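
As a rough sketch of the mixed-precision bullet above (illustrative only; `w1`, `w2`, and the squared-error loss are placeholders rather than the project's model), master weights stay in float32, the matrix multiplies run in float16, and the loss is reduced in float32:

```python
import mlx.core as mx

def forward_fp16(weights_fp32: dict, x: mx.array) -> mx.array:
    # Cast master weights and activations down for the compute-heavy matmuls
    w = {k: v.astype(mx.float16) for k, v in weights_fp32.items()}
    h = mx.maximum(x.astype(mx.float16) @ w["w1"], 0.0)  # ReLU in half precision
    logits = h @ w["w2"]
    # Reduce the loss in float32 for numerical stability
    return mx.mean(logits.astype(mx.float32) ** 2)
```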

### Training Characteristics
- Training is ~3x slower than inference due to:
- Gradient computation and synchronization
- Weight updates and broadcasting
- Memory management overhead
- Network communication latency

### Memory Considerations
- Primary bottleneck is memory bandwidth between devices
- Dynamic batch size adjustment based on memory usage
- Streaming data loading to manage memory pressure
- Gradient accumulation to handle memory constraints
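
A minimal sketch of the streaming-loader idea above, assuming newline-delimited JSON shards with a `tokens` field (both are assumptions; the repository's loader may differ):

```python
import json

def stream_batches(shard_paths, batch_size):
    """Yield token batches lazily so only one batch is resident at a time."""
    batch = []
    for path in shard_paths:
        with open(path) as f:
            for line in f:
                batch.append(json.loads(line)["tokens"])
                if len(batch) == batch_size:
                    yield batch
                    batch = []
    if batch:
        yield batch  # final partial batch
```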

### Monitoring and Stability
- Continuous performance monitoring
- Automatic batch size optimization
- Training stability checks
- Early stopping based on loss convergence
- Adaptive learning rate scheduling
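
The loss-convergence check behind early stopping could look roughly like this; the patience and tolerance values are illustrative defaults, not the project's settings:

```python
class EarlyStopper:
    """Stop training once the loss has stopped improving for `patience` evaluations."""

    def __init__(self, patience: int = 5, min_delta: float = 1e-3):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.bad_evals = 0

    def should_stop(self, val_loss: float) -> bool:
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.bad_evals = 0
        else:
            self.bad_evals += 1
        return self.bad_evals >= self.patience
```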

These specifications represent our current optimized configuration for training a 1B parameter model on consumer Apple Silicon hardware. The training time estimates are based on empirical measurements and system monitoring data from our training implementation.


## Implementation Details

### Memory Management Implementation

#### 1. Dynamic Batch Size Adjustment
```python
from src.training.performance_utils import PerformanceOptimizer

# Configure gradient checkpointing: recompute activations on alternate
# layers during the backward pass to trade compute for memory
optimizer = PerformanceOptimizer(config)
model = optimizer.setup_gradient_checkpointing(
    model,
    checkpoint_layers=[1, 3, 5, 7],  # checkpoint every other layer
)

# PerformanceOptimizer.optimize_batch_size: adjust the batch size from a
# moving average of recent memory usage
def optimize_batch_size(self, current_memory_usage: float) -> int:
    """Dynamically adjust batch size based on memory usage"""
    self.memory_history.append(current_memory_usage)

    # Use a moving average of the last 10 readings for stability
    avg_memory = np.mean(self.memory_history[-10:])

    # Shrink the batch when over the memory target, grow it when well under
    if avg_memory > self.config.target_memory_usage:
        self.current_batch_size = max(
            self.config.min_batch_size,
            int(self.current_batch_size * 0.8)
        )
    elif avg_memory < self.config.target_memory_usage * 0.8:
        self.current_batch_size = min(
            self.config.max_batch_size,
            int(self.current_batch_size * 1.2)
        )
    return self.current_batch_size
```

### 2. Dynamic Batch Sizing

#### 2. Gradient Accumulation
```python
# In the training loop: monitor memory and adjust the batch size
current_memory = mx.metal.get_active_memory() / (1024**3)  # GB
new_batch_size = optimizer.optimize_batch_size(current_memory)

# update(): choose how many micro-batches to accumulate before an optimizer step
def update(self, current_batch_size: int, memory_usage: float) -> Dict[str, Any]:
    """Update accumulation steps"""
    effective_batch_size = current_batch_size * self.current_steps

    # Accumulate more steps when memory is tight or the effective
    # batch is still below the target batch size
    if memory_usage > self.config.memory_threshold:
        new_steps = min(self.config.max_steps, self.current_steps + 1)
    elif effective_batch_size < self.config.target_batch_size:
        new_steps = min(self.config.max_steps, self.current_steps + 1)
    else:
        new_steps = self.current_steps
    return {"accumulation_steps": new_steps}  # illustrative return shape
```
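
To show how an accumulation count like `new_steps` might be consumed in the training loop, here is a sketch under assumed names (`loss_and_grad_fn`, `batches`, `opt`, and `accumulation_steps` are placeholders):

```python
from mlx.utils import tree_map

accum = None
for _ in range(accumulation_steps):
    loss, grads = loss_and_grad_fn(model, next(batches))
    # Sum gradients across micro-batches
    accum = grads if accum is None else tree_map(lambda a, b: a + b, accum, grads)

# Average the accumulated gradients, then apply a single optimizer step
opt.update(model, tree_map(lambda g: g / accumulation_steps, accum))
```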

### 3. Mixed Precision Training

#### 3. Memory Defragmentation
```python
# Enable mixed precision in the training config
config.training.mixed_precision = True

# In DistributedTrainer.train_epoch: periodically defragment memory
if batch_idx % 100 == 0:
    self.memory_manager.defragment()
```
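
What `defragment()` does is project-specific; one plausible minimal version on Apple Silicon (an assumption, not the repository's implementation) simply returns cached Metal buffers to the allocator:

```python
import mlx.core as mx

def defragment() -> None:
    # Drop cached Metal buffers so subsequent large allocations
    # can be satisfied contiguously
    mx.metal.clear_cache()
```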

## Compute Optimization

### 1. Stream Management
### Distributed Training Implementation

#### 1. Network Communication
```python
# Configure MPI parameters for optimal communication
if self.size > 1:
    MPI.Info.Set("btl_tcp_links", "4")

# Gradient synchronization across workers
if self.world.size > 1:
    grads = await self.network.sync_gradients(grads)

# Periodic weight synchronization via broadcast
if self.config.sync_weights_every > 0 and self.step % self.config.sync_weights_every == 0:
    self.model.parameters = await self.network.broadcast_weights(self.model.parameters)

# Stream management: separate streams for compute and memory ops
compute_stream = mx.Stream(mx.gpu)
memory_stream = mx.Stream(mx.cpu)

with mx.stream(compute_stream):
    # Compute operations
    loss, grads = model.train_step(batch)

with mx.stream(memory_stream):
    # Memory operations
    next_batch = dataloader.prefetch()
```
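
For reference, the synchronization step can be expressed with MLX's built-in distributed primitives; this sketch only illustrates gradient averaging and stands in for the project's `network.sync_gradients` (an assumption):

```python
import mlx.core as mx
from mlx.utils import tree_map

world = mx.distributed.init()

def sync_gradients(grads):
    """Average a gradient tree across all workers."""
    if world.size() == 1:
        return grads
    return tree_map(lambda g: mx.distributed.all_sum(g) / world.size(), grads)
```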

### 2. Operation Fusion

MLX automatically fuses operations, but you can help by grouping related operations, as in the `fused_forward` example in the code block below:

#### 2. Data Prefetching
```python
# Operation fusion: group related operations so MLX can fuse them
def fused_forward(self, x):
    # These operations will be fused
    x = self.linear1(x)
    x = mx.maximum(x, 0.0)  # ReLU
    return self.linear2(x)

# Background data prefetching worker
def start_prefetch(self):
    """Start background prefetching"""
    def prefetch_worker():
        try:
            for batch in self.dataset:
                processed = self.preprocess_function(batch)
                self.prefetch_queue.put(processed)
        except Exception as e:
            self.logger.error(f"Prefetch error: {str(e)}")
```
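
A self-contained version of the prefetching idea, using a bounded queue and a daemon thread (an illustration, not the repository's loader):

```python
import queue
import threading

def prefetch(dataset, preprocess, depth: int = 4):
    """Yield preprocessed batches while the next ones are prepared in the background."""
    q: queue.Queue = queue.Queue(maxsize=depth)
    done = object()  # sentinel marking the end of the dataset

    def worker():
        for batch in dataset:
            q.put(preprocess(batch))
        q.put(done)

    threading.Thread(target=worker, daemon=True).start()
    while (item := q.get()) is not done:
        yield item
```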

### 3. Compute Scheduling
### Performance Monitoring Implementation

#### 1. Training Metrics
```python
# Profile-based compute scheduling
scheduler = ComputeScheduler(
    compute_intensity=0.8,  # ratio of compute to memory ops
    pipeline_depth=2        # number of batches in flight
)

# check_training_health: track loss statistics, gradient norms,
# and the learning-rate schedule at a given step
def check_training_health(self, loss: float, step: int) -> Dict[str, Any]:
    """Monitor training stability"""
    metrics = {
        "loss_std": np.std(self.loss_history[-100:]),
        "loss_trend": self._calculate_trend(),
        "gradient_norm": self._compute_gradient_norm(),
        "learning_rate": self._get_current_lr(step)
    }

    if metrics["loss_std"] > 5.0:
        self.logger.warning("High loss variance detected")
    return metrics
```
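
The `_compute_gradient_norm` helper referenced above is not shown in this excerpt; a possible implementation (an assumption) computes the global L2 norm over the gradient tree:

```python
import mlx.core as mx
from mlx.utils import tree_flatten

def global_grad_norm(grads) -> float:
    """Global L2 norm across every gradient tensor in the tree."""
    total = mx.array(0.0)
    for _, g in tree_flatten(grads):
        total = total + mx.sum(g.astype(mx.float32) ** 2)
    return float(mx.sqrt(total))
```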

## Network Optimization

### 1. Gradient Synchronization

```python
# Optimize all-reduce operations by fusing small tensors
def optimized_all_reduce(grads):
    sizes = [g.size for g in grads]
    # Combine small tensors into one flat buffer
    flat_grads = mx.concatenate([g.flatten() for g in grads])
    # Single all-reduce over the combined buffer
    reduced = mx.distributed.all_sum(flat_grads)
    # Split at the running offsets and restore the original shapes
    offsets = [sum(sizes[: i + 1]) for i in range(len(sizes) - 1)]
    return [
        chunk.reshape(orig.shape)
        for chunk, orig in zip(mx.split(reduced, offsets), grads)
    ]
```

### 2. Communication Overlap

#### 2. Memory Tracking
```python
# Overlap computation and communication on separate streams
with mx.stream(compute_stream):
    # Forward pass
    loss = model(batch)

with mx.stream(comm_stream):
    # Start gradient synchronization
    grads = optimizer.reduce_gradients()

# Memory tracking: recent memory usage relative to the device limit
def _calculate_trend(self) -> float:
    """Calculate memory usage trend"""
    if len(self.memory_history) < 10:
        return 0.0

    recent = self.memory_history[-10:]
    return np.mean(recent) / mx.metal.get_memory_limit()
```

## Monitoring and Tuning

### 1. Performance Metrics

```python
from src.monitoring.dashboard import PerformanceDashboard

dashboard = PerformanceDashboard()
dashboard.track_training_metrics(
    throughput=tokens_per_second,
    communication_time=sync_time,
    cache_hit_rate=cache_hits / total_access
)
```

### 2. Memory Profiling

```python
# Monitor memory usage
memory_stats = dashboard.get_memory_stats()
print(f"Active Memory: {memory_stats['active_gb']:.2f} GB")
print(f"Peak Memory: {memory_stats['peak_gb']:.2f} GB")
```

### 3. Automatic Tuning

```python
# Enable autotuning
trainer.enable_autotuning(
    target_memory_usage=0.85,
    target_throughput=10000,
    adaptation_rate=0.1
)
```

### Optimization Guidelines

1. **Memory Management**
- Start with smaller batch sizes and gradually increase
- Monitor memory usage trends over time
- Use gradient accumulation when memory constrained
- Enable gradient checkpointing on alternate layers

2. **Network Communication**
- Use multiple TCP links for better bandwidth
- Implement asynchronous data prefetching
- Optimize gradient synchronization frequency
- Monitor network latency and adjust accordingly

3. **Training Stability**
- Track loss variance and gradient norms
- Implement early stopping with patience
- Use adaptive learning rates
- Monitor training metrics continuously

4. **Hardware Utilization**
- Balance workload across devices
- Monitor GPU utilization
- Optimize memory transfer patterns
- Use separate compute and memory streams
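
As a rough illustration of the workload-balancing guideline above, a global batch can be split across the two nodes in proportion to their memory limits (the proportional rule is an assumption, not the project's scheduler):

```python
def split_batch(global_batch: int, memory_limits_gb=(160, 96)):
    """Split a global batch across nodes proportionally to memory capacity."""
    total = sum(memory_limits_gb)
    shares = [round(global_batch * m / total) for m in memory_limits_gb]
    shares[-1] = global_batch - sum(shares[:-1])  # preserve the exact total
    return shares

# e.g. split_batch(48) -> [30, 18]
```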

These implementations reflect our current optimized configuration for training large models on Apple Silicon hardware. Adjust parameters based on your specific hardware setup and training requirements.

## Best Practices

