
Optimization Proposal: Unifying KL Divergence Checks with Flag Mechanism #830

Open
songyuc opened this issue Feb 5, 2025 · 1 comment
songyuc commented Feb 5, 2025

Current Issues:

  1. Code Duplication: KL divergence checks appear in both mini-batch and epoch loops
  2. Cognitive Overhead: Nested break statements create control flow complexity
  3. Suboptimal Stopping: potential for delayed termination once the threshold is exceeded

Proposed Solution:

# Before modification
for epoch in range(args.update_epochs):
    for start in range(0, args.batch_size, args.minibatch_size):
        # ...
        if args.target_kl and approx_kl > args.target_kl:
            break
    if args.target_kl and approx_kl > args.target_kl:
        break

# After modification
early_stop = False
for epoch in range(args.update_epochs):
    if early_stop:
        break
    
    for start in range(0, args.batch_size, args.minibatch_size):
        # ...
        if args.target_kl and approx_kl > args.target_kl:
            early_stop = True
            break

Key Benefits:

  1. Single Control Point: Centralized early stopping logic
  2. Immediate Termination: Ensures full loop exit upon first threshold violation
  3. DRY Compliance: Eliminates duplicate condition checks
  4. Behavior Consistency: Matches original intention of PPO's early stopping

Implementation Details:

  1. Add an early_stop flag variable
  2. Check order:
    • The mini-batch level check sets the flag
    • The epoch loop checks the flag as a precondition before each pass
  3. Preserves the existing algorithm semantics

Compatibility & Testing:

  1. Backward Compatibility:
    • Fully maintains original API behavior
    • No configuration changes required
  2. Test Cases (a runnable sketch is given after this list):
    # Case 1: Threshold not triggered
    target_kl = 0.2
    approx_kl_sequence = [0.15, 0.18, 0.19]
    
    # Case 2: Threshold crossed at 2nd mini-batch
    target_kl = 0.1
    approx_kl_sequence = [0.05, 0.12, 0.08]  # Should trigger at 2nd iteration
  3. Validation Metrics:
    • Number of completed epochs
    • Final KL divergence value
    • Training time comparison
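
To make the two cases concrete, here is a minimal pytest-style sketch that drives the flag-based loop with scripted approx_kl values instead of real PPO losses. The helper run_update_epochs, its update_epochs parameter, and the test names are illustrative only and are not part of the ManiSkill codebase.

# A self-contained simulation of the early-stop control flow only; each value
# in approx_kl_sequence stands in for the approx_kl of one mini-batch update.
def run_update_epochs(approx_kl_sequence, target_kl, update_epochs=3):
    """Return how many mini-batch updates run before (or without) early stop."""
    kl_values = iter(approx_kl_sequence)
    updates_done = 0
    early_stop = False
    for epoch in range(update_epochs):
        if early_stop:
            break
        for approx_kl in kl_values:
            updates_done += 1
            if target_kl is not None and approx_kl > target_kl:
                early_stop = True
                break
    return updates_done


def test_threshold_not_triggered():
    # Case 1: no KL value exceeds 0.2, so all three mini-batch updates run.
    assert run_update_epochs([0.15, 0.18, 0.19], target_kl=0.2) == 3


def test_threshold_crossed_at_second_minibatch():
    # Case 2: 0.12 > 0.1 stops the update loop after the 2nd mini-batch.
    assert run_update_epochs([0.05, 0.12, 0.08], target_kl=0.1) == 2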

Performance Considerations:

  1. Memory: Negligible overhead (single boolean flag)
  2. Computation: Eliminates redundant KL checks
  3. Early Exit: Same computational savings as original implementation

Supplementary Recommendations:

  1. Diagnostic Logging:
    if early_stop:
        logger.info(f"Early stopping at epoch {epoch}: KL {approx_kl:.4f} > {args.target_kl}")
  2. Documentation Update:
    ## Early Stopping
    Policy-update epochs for the current batch terminate immediately when the
    approximate KL divergence exceeds the `target_kl` threshold, enforcing a
    strict constraint on each policy update.

This proposal maintains algorithmic fidelity while improving code quality and runtime behavior. The change is minimally invasive but provides significant maintenance benefits. I'm available to prepare a PR with these changes if needed.

@StoneT2000 (Member)

We should just use the for/else loop that the ppo_fast.py code does.

The suggestion here is a bit more verbose than that.
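
For reference, a minimal sketch of the for/else structure being suggested; the loop bounds and variable names mirror the snippets above in this issue, so this is an assumed rendering of the pattern rather than a verbatim copy of ppo_fast.py:

for epoch in range(args.update_epochs):
    for start in range(0, args.batch_size, args.minibatch_size):
        # ... compute losses and approx_kl for this mini-batch ...
        if args.target_kl and approx_kl > args.target_kl:
            break  # leave the mini-batch loop; the inner else clause is skipped
    else:
        continue  # mini-batch loop finished without break: proceed to next epoch
    break  # reached only via the inner break: stop all update epochs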

songyuc added a commit to songyuc/ManiSkill that referenced this issue Feb 10, 2025
   - Replace dual break checks with Pythonic for-else structure
   - Improve code readability while maintaining original logic
   - Related to issue haosulab#830