When developing DistAttn, we discovered a better gradient checkpointing strategy for models that use FlashAttention (FA). FA already rematerializes the attention computation inside its backward kernel: it recomputes the attention scores from Q, K, and V rather than storing the full attention matrix. Placing the FA call inside a checkpointed region therefore recomputes a forward pass whose large intermediates were never stored in the first place, so the checkpoint boundaries should be placed so that the FA call itself is excluded from recomputation. Notably, this is not specific to DistAttn; it applies to any model that uses FA.