def get_batch(
    dataset: npt.NDArray, batch_size: int, context_length: int, device: DeviceLikeType
) -> tuple[Tensor, Tensor]:
    """Sample a random batch of (input, next-token target) windows from a token array.

    Args:
        dataset: Token ids, assumed to be a 1-D array supporting NumPy fancy
            indexing (a 2-D index array yields a ``(batch, context)`` result).
        batch_size: Number of windows to sample.
        context_length: Number of tokens in each window.
        device: Device the returned tensors are moved to.

    Returns:
        ``(inputs, targets)``: two int64 tensors of shape
        ``(batch_size, context_length)`` on ``device``, where
        ``targets[b, t] == dataset[start_b + t + 1]``. That is, ``targets`` is
        ``inputs`` shifted one token to the right.

    Raises:
        ValueError: If ``dataset`` has fewer than ``context_length + 1`` tokens,
            so not even one (input, target) pair fits.
    """
    # Why write this by hand instead of using a DataLoader?
    # We *sample* random windows on every training step rather than iterating
    # a stateful loader with prefetching workers. This works because:
    # (1) the full token array fits in RAM, so sampling is just a random index
    #     plus a gather, and there is no I/O for prefetch workers to hide;
    # (2) LLM training rarely "finishes" an epoch, so stateless random sampling
    #     is simpler and effectively equivalent.
    # This differs from vision tasks, where each input sample can be large;
    # here each sample is just a short run of token ids.
    dataset_token_count = len(dataset)

    # Valid starts are 0 .. N - context_length - 1. The target window is shifted
    # by one, so its last token dataset[start + context_length] must exist.
    num_valid_starts = dataset_token_count - context_length
    if num_valid_starts < 1:
        raise ValueError(
            f"dataset has {dataset_token_count} tokens; need at least "
            f"context_length + 1 = {context_length + 1} to form one (input, target) pair"
        )
    start_indices = np.random.randint(0, num_valid_starts, size=batch_size)

    # Row b of the index matrix is start_indices[b] + [0, 1, ..., context_length - 1].
    # Broadcasting a column vector against a row vector builds it in one shot.
    indexing_arr = start_indices[:, None] + np.arange(context_length)

    # Widen token ids to int64 so the tensors can be used directly as integer
    # indices regardless of the array's storage dtype (e.g. uint16).
    # copy=False makes this a no-op when the data is already int64.
    inputs = dataset[indexing_arr].astype(np.int64, copy=False)
    targets = dataset[indexing_arr + 1].astype(np.int64, copy=False)

    # Plain synchronous copies. Pinned memory plus non_blocking transfers would
    # allow an async copy, but add memory pressure and slower allocation, and
    # nothing here overlaps data loading with compute anyway.
    return torch.from_numpy(inputs).to(device), torch.from_numpy(targets).to(device)
 
def gradient_clipping(parameters: Iterable[torch.nn.Parameter], max_l2_norm: float) -> None:
    """Rescale gradients in place so their combined L2 norm is at most ``max_l2_norm``.

    All gradients are treated as one flattened vector g. If ``||g||_2`` exceeds
    ``max_l2_norm``, every gradient is multiplied by the same factor
    ``max_l2_norm / (||g||_2 + 1e-6)``. Otherwise the gradients are left unchanged.
    Parameters whose ``.grad`` is ``None`` are skipped.

    Args:
        parameters: Parameters whose ``.grad`` tensors are clipped. May be a
            one-shot iterator (e.g. the generator returned by
            ``nn.Module.parameters()``); it is consumed exactly once.
        max_l2_norm: Upper bound on the global L2 norm of all gradients.
    """
    # Why one shared factor? Clipping each tensor separately, or clipping
    # element-wise, would change the *direction* of the overall update.
    # Scaling everything by a single factor keeps the direction and only
    # shrinks the length.

    # Collect grads in a single pass. Iterating `parameters` twice would make
    # the second loop silently do nothing when a generator is passed in.
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return

    with torch.no_grad():
        # The norm of the per-tensor norms equals the norm of the concatenation.
        # This holds for any p-norm, since (sum_i ((sum_j |g_ij|^p)^(1/p))^p)^(1/p)
        # = (sum_ij |g_ij|^p)^(1/p). So we never materialize one huge vector.
        # Per-tensor norms are moved to one device so they can be stacked.
        norm_device = grads[0].device
        total_norm = torch.linalg.vector_norm(
            torch.stack([torch.linalg.vector_norm(g).to(norm_device) for g in grads])
        )
        # Solve ||c * g|| = max_l2_norm for c; the epsilon avoids division by zero.
        # Clamp at 1 so gradients already under the limit are never scaled *up*.
        scale = torch.clamp(max_l2_norm / (total_norm + 1e-6), max=1.0)
        for g in grads:
            g.mul_(scale.to(g.device))