fix gradients
winglian committed Feb 26, 2025
1 parent a2e52a2 commit c011405
Showing 4 changed files with 463 additions and 78 deletions.
108 changes: 30 additions & 78 deletions src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py
@@ -6,73 +6,15 @@
 import triton
 import triton.language as tl
 
-
-# Helper function for computing logsumexp
-@triton.jit
-def logsumexp_kernel(
-    logits_ptr,
-    output_ptr,
-    B,
-    S,
-    V,  # batch size, seq len, vocab size
-    stride_b,
-    stride_s,
-    stride_v,
-    out_stride_b,
-    out_stride_s,
-    BLOCK_SIZE: tl.constexpr,
-):
-    # Program ID
-    pid = tl.program_id(0)
-    batch_idx = pid // S
-    seq_idx = pid % S
-
-    # Bounds check
-    if batch_idx >= B or seq_idx >= S:
-        return
-
-    # Pointers
-    logits_base = logits_ptr + batch_idx * stride_b + seq_idx * stride_s
-
-    # Find maximum for numerical stability
-    max_val = -float("inf")
-    for v_offset in range(0, V, BLOCK_SIZE):
-        v_size = min(BLOCK_SIZE, V - v_offset)
-        mask = tl.arange(0, BLOCK_SIZE) < v_size
-
-        logits_block = tl.load(
-            logits_base + (v_offset + tl.arange(0, BLOCK_SIZE)) * stride_v,
-            mask=mask,
-            other=-float("inf"),
-        )
-        max_val = tl.maximum(max_val, tl.max(logits_block, axis=0))
-
-    # Compute sum of exp(logit - max_val)
-    sum_exp = 0.0
-    for v_offset in range(0, V, BLOCK_SIZE):
-        v_size = min(BLOCK_SIZE, V - v_offset)
-        mask = tl.arange(0, BLOCK_SIZE) < v_size
-
-        logits_block = tl.load(
-            logits_base + (v_offset + tl.arange(0, BLOCK_SIZE)) * stride_v,
-            mask=mask,
-            other=-float("inf"),
-        )
-        sum_exp += tl.sum(tl.exp(logits_block - max_val), axis=0)
-
-    # Compute logsumexp
-    result = max_val + tl.log(sum_exp)
-
-    # Store result
-    tl.store(output_ptr + batch_idx * out_stride_b + seq_idx * out_stride_s, result)
+from .logsumexp import logsumexp_kernel
 
 
 @triton.jit
 def grad_softmax_kernel(
     grad_student_logits_ptr,
-    student_logits_ptr,
     target_token_ids_ptr,
     teacher_probs_ptr,
+    student_probs_ptr,
     mask_ptr,
     B,
     S,
@@ -82,15 +24,15 @@ def grad_softmax_kernel(
     stride_gl_b,
     stride_gl_s,
     stride_gl_v,
-    stride_l_b,
-    stride_l_s,
-    stride_l_v,
     stride_t_b,
     stride_t_s,
     stride_t_k,
     stride_p_b,
     stride_p_s,
     stride_p_k,
+    stride_sp_b,
+    stride_sp_s,
+    stride_sp_k,
     stride_m_b,
     stride_m_s,
     stride_m_k,
@@ -116,21 +58,31 @@ def grad_softmax_kernel(
     teacher_probs_base = (
         teacher_probs_ptr + batch_idx * stride_p_b + seq_idx * stride_p_s
     )
+    student_probs_base = (
+        student_probs_ptr + batch_idx * stride_sp_b + seq_idx * stride_sp_s
+    )
     mask_base = mask_ptr + batch_idx * stride_m_b + seq_idx * stride_m_s
 
     # Softmax over full vocab case
     for k in range(0, K):
         # Load token ID, teacher prob, and mask for this position
         token_id = tl.load(token_ids_base + k * stride_t_k)
         teacher_prob = tl.load(teacher_probs_base + k * stride_p_k)
+        student_prob_k = tl.load(student_probs_base + k * stride_sp_k)
         mask_val = tl.load(mask_base + k * stride_m_k)
 
-        # Apply mask by scaling gradient to zero if masked
-        grad_val = teacher_prob * scale * mask_val
-
-        # Update the gradient for this token's position in the vocabulary
-        # Only contributes if mask_val is non-zero
-        tl.atomic_add(grad_logits_base + token_id * stride_gl_v, grad_val)
+        for j in range(0, K):
+            other_token_id = tl.load(token_ids_base + j * stride_t_k)
+            student_prob_j = tl.load(student_probs_base + j * stride_sp_k)
+            mask_j = tl.load(mask_base + j * stride_m_k)
+            combined_mask = mask_val * mask_j
+            is_diagonal = tl.where(j == k, 1.0, 0.0)
+            self_grad = teacher_prob * (1.0 - student_prob_k)
+            cross_grad = -teacher_prob * student_prob_j
+            grad_val = (
+                -(self_grad * is_diagonal + cross_grad * (1.0 - is_diagonal))
+                * scale
+                * combined_mask
+            )
+            tl.atomic_add(grad_logits_base + other_token_id * stride_gl_v, grad_val)
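For intuition: summed over the top-k index k, the self/cross terms added above reduce to the standard softmax cross-entropy gradient restricted to the teacher's top-k set, dL/dz_j = s_j * sum_k(t_k) - t_j for L = -sum_k t_k * log(s_k). A minimal PyTorch autograd sanity check of that identity (illustrative only, not part of this commit; `topk` and `t` are made-up inputs for a single sequence position):

import torch

torch.manual_seed(0)
V, K = 11, 4
z = torch.randn(V, requires_grad=True)  # student logits for one position
topk = torch.randperm(V)[:K]            # stand-in for the teacher's top-k token ids
t = torch.rand(K)                       # teacher probabilities on the top-k set

s = torch.softmax(z, dim=-1)
loss = -(t * torch.log(s[topk])).sum()  # top-k forward-KL / cross-entropy term
loss.backward()

# Closed form matching the kernel's diagonal/off-diagonal accumulation:
# dL/dz_j = s_j * t.sum() - t_j for every j in the top-k set.
manual = s * t.sum()
manual[topk] -= t
assert torch.allclose(z.grad[topk], manual[topk], atol=1e-6)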


@triton.jit
@@ -338,7 +290,6 @@ def forward(
         )
         kd_loss = token_losses.sum()
 
-        # pylint: disable=duplicate-code
         # Apply temperature scaling
         if kd_temperature != 1.0:
             kd_loss = kd_loss * (kd_temperature**2)
@@ -376,7 +327,7 @@ def backward(ctx, grad_output):
         # Compute scaling factor
         scale = grad_output.item()
 
-        # Apply temperature scaling
+        # Apply temperature scaling from forward pass
         if kd_temperature != 1.0:
             scale = scale * (kd_temperature**2)
 
@@ -386,7 +337,8 @@ def backward(ctx, grad_output):
         else:
             scale = scale / float(target_mask.sum().item())
 
-        # If we used temperature scaling in the forward pass, we need to apply it in the backward pass
+        # Apply chain rule for temperature scaling (1/temperature)
+        # This comes from d(logits/temperature)/d(logits) = 1/temperature
         if kd_temperature != 1.0:
             scale = scale / kd_temperature
 
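Assuming the forward pass follows the usual knowledge-distillation convention, softmax over logits / kd_temperature with the loss multiplied by kd_temperature**2 (the forward hunk above shows the multiplication), the two backward scalings compose as

\[ \frac{\partial}{\partial z}\Big[T^{2}\,L(z/T)\Big] = T^{2}\cdot\frac{1}{T}\,L'(z/T) = T\,L'(z/T) \]

i.e. the `* kd_temperature**2` and `/ kd_temperature` steps leave a net factor of kd_temperature on the logit gradient.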
@@ -434,9 +386,9 @@ def backward(ctx, grad_output):
         grid = (batch_size * seq_len,)
         grad_softmax_kernel[grid](
             grad_student_logits.contiguous(),
-            student_logits.contiguous(),
             target_token_ids.contiguous(),
             teacher_probs.contiguous(),
+            student_probs.contiguous(),
             target_mask.contiguous(),
             batch_size,
             seq_len,
@@ -446,15 +398,15 @@ def backward(ctx, grad_output):
             grad_student_logits.stride(0),
             grad_student_logits.stride(1),
             grad_student_logits.stride(2),
-            student_logits.stride(0),
-            student_logits.stride(1),
-            student_logits.stride(2),
             target_token_ids.stride(0),
             target_token_ids.stride(1),
             target_token_ids.stride(2),
             teacher_probs.stride(0),
             teacher_probs.stride(1),
             teacher_probs.stride(2),
+            student_probs.stride(0),
+            student_probs.stride(1),
+            student_probs.stride(2),
             target_mask.stride(0),
             target_mask.stride(1),
             target_mask.stride(2),
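The kernel call now passes precomputed student probabilities where it previously passed raw student logits. The visible hunks don't include where student_probs is built; one plausible construction (an assumption, sketched for clarity, with student_lse as a hypothetical name for the (B, S) output of the logsumexp kernel added below) gathers the student logits at the teacher's top-k ids and normalizes them:

# Sketch (assumption): softmax values at the top-k ids only, via
# p = exp(z_topk - logsumexp(z)), without materializing the full-vocab softmax.
student_topk_logits = torch.gather(student_logits, -1, target_token_ids)
student_probs = torch.exp(student_topk_logits - student_lse.unsqueeze(-1))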
67 changes: 67 additions & 0 deletions src/axolotl/integrations/kd/topk_logprob/logsumexp.py
@@ -0,0 +1,67 @@
"""
Optimized Triton kernels for logsumexp
"""
# pylint: disable=invalid-name,unused-argument
import triton
import triton.language as tl


# Helper function for computing logsumexp
@triton.jit
def logsumexp_kernel(
    logits_ptr,
    output_ptr,
    B,
    S,
    V,  # batch size, seq len, vocab size
    stride_b,
    stride_s,
    stride_v,
    out_stride_b,
    out_stride_s,
    BLOCK_SIZE: tl.constexpr,
):
    # Program ID
    # pylint: disable=duplicate-code
    pid = tl.program_id(0)
    batch_idx = pid // S
    seq_idx = pid % S

    # Bounds check
    if batch_idx >= B or seq_idx >= S:
        return

    # Pointers
    logits_base = logits_ptr + batch_idx * stride_b + seq_idx * stride_s

    # Find maximum for numerical stability
    max_val = -float("inf")
    for v_offset in range(0, V, BLOCK_SIZE):
        v_size = min(BLOCK_SIZE, V - v_offset)
        mask = tl.arange(0, BLOCK_SIZE) < v_size

        logits_block = tl.load(
            logits_base + (v_offset + tl.arange(0, BLOCK_SIZE)) * stride_v,
            mask=mask,
            other=-float("inf"),
        )
        max_val = tl.maximum(max_val, tl.max(logits_block, axis=0))

    # Compute sum of exp(logit - max_val)
    sum_exp = 0.0
    for v_offset in range(0, V, BLOCK_SIZE):
        v_size = min(BLOCK_SIZE, V - v_offset)
        mask = tl.arange(0, BLOCK_SIZE) < v_size

        logits_block = tl.load(
            logits_base + (v_offset + tl.arange(0, BLOCK_SIZE)) * stride_v,
            mask=mask,
            other=-float("inf"),
        )
        sum_exp += tl.sum(tl.exp(logits_block - max_val), axis=0)

    # Compute logsumexp
    result = max_val + tl.log(sum_exp)

    # Store result
    tl.store(output_ptr + batch_idx * out_stride_b + seq_idx * out_stride_s, result)
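A host-side launch of this kernel might look like the sketch below (assumptions: logsumexp_triton is a hypothetical wrapper name, and the integration's real call sites may differ); it also checks the two-pass max-then-sum-exp scheme against torch.logsumexp:

import torch
from axolotl.integrations.kd.topk_logprob.logsumexp import logsumexp_kernel


def logsumexp_triton(logits: torch.Tensor, block_size: int = 1024) -> torch.Tensor:
    # Row-wise logsumexp over the vocab dim of a (B, S, V) CUDA tensor.
    B, S, V = logits.shape
    out = torch.empty((B, S), device=logits.device, dtype=logits.dtype)
    grid = (B * S,)  # one program per (batch, seq) position
    logsumexp_kernel[grid](
        logits, out,
        B, S, V,
        logits.stride(0), logits.stride(1), logits.stride(2),
        out.stride(0), out.stride(1),
        BLOCK_SIZE=block_size,
    )
    return out


if torch.cuda.is_available():
    x = torch.randn(2, 3, 32000, device="cuda")
    assert torch.allclose(logsumexp_triton(x), torch.logsumexp(x, dim=-1), atol=1e-4)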
