@nov05
Last active February 8, 2025 06:40
  • ⚠️🟢 Issue: training fails with the CUDA NLLLoss assertion `t >= 0 && t < n_classes`
[1,mpirank:0,algo-1]<stderr>:../aten/src/ATen/native/cuda/Loss.cu:242: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [0,0,0] Assertion `t >= 0 && t < n_classes` failed.
[1,mpirank:0,algo-1]<stderr>:../aten/src/ATen/native/cuda/Loss.cu:242: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [6,0,0] Assertion `t >= 0 && t < n_classes` failed.
[1,mpirank:0,algo-1]<stderr>:../aten/src/ATen/native/cuda/Loss.cu:242: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [30,0,0] Assertion `t >= 0 && t < n_classes` failed.
...
[1,mpirank:1,algo-2]<stdout>:  File "train.py", line 675, in <module>
[1,mpirank:1,algo-2]<stdout>:    main(task)
[1,mpirank:1,algo-2]<stdout>:  File "train.py", line 572, in main
[1,mpirank:1,algo-2]<stdout>:    train(task)
[1,mpirank:1,algo-2]<stdout>:  File "train.py", line 277, in train
[1,mpirank:1,algo-2]<stdout>:    loss.backward()
[1,mpirank:1,algo-2]<stdout>:  File "/opt/conda/lib/python3.9/site-packages/torch/_tensor.py", line 488, in backward
[1,mpirank:1,algo-2]<stdout>:    torch.autograd.backward(
[1,mpirank:1,algo-2]<stdout>:  File "/opt/conda/lib/python3.9/site-packages/torch/autograd/__init__.py", line 197, in backward
[1,mpirank:1,algo-2]<stdout>:    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[1,mpirank:1,algo-2]<stdout>:RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`
[1,mpirank:0,algo-1]<stdout>:  File "/opt/conda/lib/python3.9/site-packages/torch/autograd/__init__.py", line 197, in backward
[1,mpirank:0,algo-1]<stdout>:    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[1,mpirank:0,algo-1]<stdout>:RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`
  • ✅ Solution:
def label_transform(x):
    ## Original labels are (1,2,3,4,5)
    ## Convert to (0,1,2,3,4)
    return torch.tensor(int(x.decode())-1, dtype=torch.int64)
    
class WebDatasetDDP(IterableDataset):
    def __init__(self,
    ...
        super().__init__()
        self.dataset = (
        ...
        wds.DataPipeline(
        ...
        wds.map_tuple(
            key_transform,
            train_transform, 
            label_transform,  
        ),
        ...
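
For reference, here is a minimal sanity check of the remapping (a sketch; n_classes = 5 and the byte-string labels are assumptions taken from the snippet above), which can run before any target ever reaches the loss:

import torch

n_classes = 5  ## assumed from the (1,2,3,4,5) label set above

def label_transform(x):
    ## Original labels are (1,2,3,4,5); convert to (0,1,2,3,4)
    return torch.tensor(int(x.decode()) - 1, dtype=torch.int64)

## WebDataset yields each label as a byte string, e.g. b"3"
targets = torch.stack([label_transform(x) for x in (b"1", b"3", b"5")])
print(targets)  ## tensor([0, 2, 4])

## Guard against out-of-range targets before computing NLLLoss/CrossEntropyLoss
assert torch.all((targets >= 0) & (targets < n_classes)), "Invalid target label encountered!"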

  • ChatGPT (might be wrong)

The error message you're encountering, Assertion 't >= 0 && t < n_classes' failed, is related to the Negative Log Likelihood Loss (NLLLoss) in PyTorch, which is typically used for classification tasks. This assertion error means that one or more of the target labels in your dataset are either negative or greater than or equal to the number of classes (n_classes). In simpler terms, the model is encountering a target value (label) that is not within the valid range for classification.

Cause

The issue typically occurs because the target labels in your dataset have values that are outside the expected range of class indices. For a classification task with n_classes possible categories, valid target labels should range from 0 to n_classes - 1.

Solutions:

  1. Check the target labels: Ensure that the target labels in your dataset are integers and fall within the valid range [0, n_classes - 1].

    • If you have n_classes = 5, then the valid labels are 0, 1, 2, 3, 4.
    • Any label outside this range, such as -1 or 5, will trigger the assertion error.

    How to check: Add a check to validate the range of your target labels before passing them to the model.

    assert torch.all((target >= 0) & (target < n_classes)), "Invalid target label encountered!"
  2. Preprocess target labels: If your dataset contains labels that are outside the valid range, preprocess the target labels to ensure they are within the correct range.

    Example:

    # Assuming you need to remap the target labels to be in the range [0, n_classes-1]
    target = target.clamp(0, n_classes - 1)
  3. Check for invalid class weights: If you're using class weights in your loss function (e.g., passing weight to NLLLoss or CrossEntropyLoss), ensure that the class weights correspond to the number of classes and that they match the indices of the classes correctly.

  4. Verify number of classes in your dataset: Double-check that the n_classes variable in your code corresponds to the actual number of unique classes in your dataset. If n_classes is set incorrectly, targets outside this range will cause the error.

    You can inspect the number of unique classes using:

    unique_classes = torch.unique(targets)
    print(unique_classes)  # Ensure this matches with your n_classes
  5. Label smoothing: If you use label smoothing or other data augmentation techniques, ensure that these methods don't produce invalid target labels.

Example Fix:

If you're using a dataset with 5 classes (indexed from 0 to 4), but your dataset contains labels like 5 or -1, you could preprocess your targets as follows:

# Suppose you have targets that are incorrectly labeled
targets = torch.tensor([0, 1, 2, 5, -1])  # Example target with invalid labels

# Clamp target values between 0 and n_classes - 1
n_classes = 5
targets = targets.clamp(0, n_classes - 1)

print(targets)  # Output will be [0, 1, 2, 4, 0]

Debugging:

  • Log the target values right before passing them to the loss function to check if they are within the correct range.
  • Print the unique classes in the dataset to confirm if there are any out-of-bound labels:
    print(torch.unique(targets))  # Should print valid class indices only

Let me know how these steps work for you or if you need further assistance!

nov05 commented Feb 7, 2025

⚠️🟢 Issue: early stopping seems to cause an error in AllGather and eventually a timeout error?

✅ Solution: use dist.broadcast() and dist.barrier() to synchronize the early-stopping decision across all ranks

https://docs.aws.amazon.com/sagemaker/latest/dg/data-parallel-intro.html

SMDDP AllGather collective operation
AllGather is a collective operation where each worker starts with an input buffer, and then concatenates or gathers the input buffers from all other workers into an output buffer.
AllGather is heavily used in distributed training techniques such as sharded data parallelism where each individual worker holds a fraction of a model, or a sharded layer. The workers call AllGather before forward and backward passes to reconstruct the sharded layers. The forward and backward passes continue onward after the parameters are all gathered. During the backward pass, each worker also calls ReduceScatter to collect (reduce) gradients and break (scatter) them into gradient shards to update the corresponding sharded layer.

distributed model parallelism, sharded data parallelism...
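
For illustration only, a minimal all_gather sketch using the public torch.distributed API (not taken from the AWS docs; it assumes the process group has already been initialized by the SMDDP launcher and that device is the local GPU):

import torch
import torch.distributed as dist

def gather_scalar(value: float, device: torch.device) -> list:
    ## Every rank contributes one value and receives the values from all ranks.
    world_size = dist.get_world_size()
    local = torch.tensor([value], device=device)
    buffers = [torch.zeros_like(local) for _ in range(world_size)]
    dist.all_gather(buffers, local)
    return [t.item() for t in buffers]

## e.g. collect each rank's validation loss so rank 0 can average them;
## if one rank stops early and skips this call, the other ranks block here until they time out.
# val_losses = gather_scalar(val_loss, device)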

  • Error:
terminate called after throwing an instance of ':SMDDPTimeoutError:'
 what()
 #011Timeout: A call to 'allGather' has taken over 1800.000000 seconds. Terminating the distributed job.It might be one of the workers failed during forward and backward propagation and failed to call "allGather".
 #011Extend timeout using dist.init_process_group(timeout=timedelta(minutes=60)
 #011Extend timeout using dist.init(timeout=timedelta(minutes=60)
 #011or refer to the debugging guide. Verify that all ranks call
 #011collective operations in the same order and within timeout period.
 [algo-2:00105] *** Process received signal ***
 [algo-2:00105] Signal: Aborted (6)
 [algo-2:00105] Signal code:  (-6)
 [algo-2:00105] :[ 0] :/usr/lib/x86_64-linux-gnu/libpthread.so.0(+0x14420)[0x7f4e7a93c420]
 [algo-2:00105] :[ 1]
 /usr/lib/x86_64-linux-gnu/libc.so.6(gsignal+0xcb)[0x7f4e7a61d00b]
 [algo-2:00105] :[ 2] :/usr/lib/x86_64-linux-gnu/libc.so.6(abort+0x12b)[0x7f4e7a5fc859]
 [algo-2:00105]
 [ 3] :/opt/conda/lib/libstdc++.so.6(_ZN9__gnu_cxx27__verbose_terminate_handlerEv+0xc0)[0x7f4e4f792026]
 [ 4]
 /opt/conda/lib/libstdc++.so.6(+0xb0514)[0x7f4e4f790514]
 [ 5] :/opt/conda/lib/libstdc++.so.6(+0xb0566)[0x7f4e4f790566]
 [ 6] :/opt/conda/lib/libstdc++.so.6(__cxa_rethrow+0x0)[0x7f4e4f790758]
 [algo-2:00105] :[ 7] :/opt/conda/lib/python3.9/site-packages/smdistributed/dataparallel/lib/libsmddpcpp.so(_Z24default_timeout_callbacklSs+0x279)[0x7f4de9ea542f]
 [ 8] :/opt/conda/lib/python3.9/site-packages/smdistributed/dataparallel/lib/libsmddpcpp.so(_ZNSt17_Function_handlerIFvlSsEPS0_E9_M_invokeERKSt9_Any_dataOlOSs+0x70)[0x7f4de9ea7281]
 [algo-2:00105] :[ 9]
 /opt/conda/lib/python3.9/site-packages/smdistributed/dataparallel/lib/libsmddpcpp.so(_ZNKSt8functionIFvlSsEEclElSs+0x65)[0x7f4de9ea6e0f]
 [algo-2:00105] :[10]
 /opt/conda/lib/python3.9/site-packages/smdistributed/dataparallel/lib/libsmddpcpp.so(_ZN17TimeoutController9controlfnEv+0x471)[0x7f4de9ea59ef]
 [algo-2:00105] :[11] :/opt/conda/lib/python3.9/site-packages/smdistributed/dataparallel/lib/libsmddpcpp.so(_ZSt13__invoke_implIvM17TimeoutControllerFvvEPS0_JEET_St21__invoke_memfun_derefOT0_OT1_DpOT2_+0x69)[0x7f4de9ea7eb7]
 [algo-2:00105] :[12] :/opt/conda/lib/python3.9/site-packages/smdistributed/dataparallel/lib/libsmddpcpp.so(_ZSt8__invokeIM17TimeoutControllerFvvEJPS0_EENSt15__invoke_resultIT_JDpT0_EE4typeEOS5_DpOS6_+0x3b)[0x7f4de9ea7de6]
 [13] :/opt/conda/lib/python3.9/site-packages/smdistributed/dataparallel/lib/libsmddpcpp.so(_ZNSt6thread8_InvokerISt5tupleIJM17TimeoutControllerFvvEPS2_EEE9_M_invokeIJLm0ELm1EEEEvSt12_Index_tupleIJXspT_EEE+0x47)[0x7f4de9ea7d45]
 [algo-2:00105] :[14]
 /opt/conda/lib/python3.9/site-packages/smdistributed/dataparallel/lib/libsmddpcpp.so(_ZNSt6thread8_InvokerISt5tupleIJM17TimeoutControllerFvvEPS2_EEEclEv+0x1c)[0x7f4de9ea7cfc]
 [algo-2:00105] :[15] :/opt/conda/lib/python3.9/site-packages/smdistributed/dataparallel/lib/libsmddpcpp.so(_ZNSt6thread11_State_implINS_8_InvokerISt5tupleIJM17TimeoutControllerFvvEPS3_EEEEE6_M_runEv+0x20)[0x7f4de9ea7cdc]
 [algo-2:00105] :[16]
 /opt/conda/lib/libstdc++.so.6(+0xcda93)[0x7f4e4f7ada93]
 [algo-2:00105] :[17] :/usr/lib/x86_64-linux-gnu/libpthread.so.0(+0x8609)[0x7f4e7a930609]
 [algo-2:00105] :[18] :/usr/lib/x86_64-linux-gnu/libc.so.6(clone+0x43)[0x7f4e7a6f9353]
 [algo-2:00105] *** End of error message ***
 /opt/conda/bin/runwithenvvars: line 55:   105 Aborted                 (core dumped) $@
 --------------------------------------------------------------------------
 Primary job  terminated normally, but 1 process returned
 a non-zero exit code. Per user-direction, the job has been aborted.
 mpirun.real detected that one or more processes exited with non-zero status, thus causing
 the job to be terminated. The first process to do so was
 Process name: [[41139,1],1]
 Exit code:    134"
Command "mpirun --host algo-1:1,algo-2:1 -np 2 --allow-run-as-root --tag-output --oversubscribe -mca btl_tcp_if_include eth0 -mca oob_tcp_if_include eth0 -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent 2 -x NCCL_SOCKET_IFNAME=eth0 -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -x SMDATAPARALLEL_USE_HOMOGENEOUS=1 -x FI_PROVIDER=efa -x RDMAV_FORK_SAFE=1 -x LD_PRELOAD=/opt/conda/lib/python3.9/site-packages/gethostname.cpython-39-x86_64-linux-gnu.so -x SMDATAPARALLEL_SERVER_ADDR=algo-1 -x SMDATAPARALLEL_SERVER_PORT=7592 -x SAGEMAKER_INSTANCE_TYPE=ml.g4dn.xlarge smddprun /opt/conda/bin/python3.9 -m mpi4py train_10k.py --batch-size 128 --class-weights-dict 1=1.7004885993485341,2=0.9083079599826012,3=0.7832708177044261,4=0.8799831436999579,5=1.1137066666666666 --debug False --early-stopping-patience 5 --epochs 30 --lr-sched-step-size 10 --model-arch resnet34 --opt-learning-rate 0.1 --opt-momentum 0.9 --opt-type sgd --test-data-path s3://p5-amazon-bin-images/webdataset/test/test-shard-{000000..000001}.tar --test-data-size 1567 --train-data-path s3://p5-amazon-bin-images/webdataset/train/train-shard-{000000..000007}.tar --train-data-size 7308 --val-data-path s3://p5-amazon-bin-images/webdataset/val/val-shard-{000000..000001}.tar --val-data-size 1566 --wandb True"
2025-02-07 19:23:45,918 sagemaker-training-toolkit ERROR    Encountered exit_code 1
2025-02-07 19:23:45,916 sagemaker-training-toolkit INFO     Invoked on_terminate from psutil.wait_for_procs
2025-02-07 19:23:45,917 sagemaker-training-toolkit INFO     process psutil.Process(pid=86, name='orted', status='terminated', started='18:40:47') terminated with exit code None
2025-02-07 19:23:45,917 sagemaker-training-toolkit INFO     Reporting status for ORTEd process. gone: [psutil.Process(pid=86, name='orted', status='terminated', started='18:40:47')] alive: []
2025-02-07 19:23:45,917 sagemaker-training-toolkit INFO     Orted process exited

2025-02-07 19:23:56 Uploading - Uploading generated training model
2025-02-07 19:24:14 Failed - Training job failed

  • ChatGPT (might be wrong)

Yes, you're correct! In distributed training with multiple nodes, the state (like validation loss or early stopping condition) might not be the same across all nodes, leading to inconsistent stopping behavior if not handled properly.

Using dist.barrier() ensures synchronization across all nodes in the distributed environment. This allows you to make sure that once a stopping condition is met (e.g., early stopping), all nodes will reach a consensus and stop the training at the same time.

Here’s how you can use dist.barrier() for early stopping in your SageMaker Distributed Data Parallel (SDP) setup:

  1. Rank 0 Node (Master Node) Handling: Compute the early stopping condition (e.g., checking the validation loss) only on the master node (rank 0).

  2. Broadcast the Early Stopping Signal: Use dist.broadcast() to share the early stopping decision from the master node to all worker nodes.

  3. Sync Nodes with dist.barrier(): Ensure all nodes stop training together based on the early stopping decision.

Here’s a revised version of the code with these ideas implemented:

import torch
import torch.distributed as dist

# Assuming you're using SageMaker's Distributed Data Parallel (SDP)
from sagemaker_pytorch_ddp import DistributedDataParallel as DDP

# Your model, data, optimizer, etc.
model = MyModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

# Early stopping criteria
early_stopping_patience = 5
best_val_loss = float('inf')
epochs_without_improvement = 0
early_stop = False  # Flag for early stopping, will be broadcasted

# Distributed initialization
rank = dist.get_rank()  # Rank of the current node

for epoch in range(num_epochs):
    model.train()
    
    # Training loop
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()

    # Validation step (assume val_loader and loss_fn are defined)
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for val_data, val_target in val_loader:
            output = model(val_data)
            loss = loss_fn(output, val_target)
            val_loss += loss.item()

    val_loss /= len(val_loader)

    # Early stopping check (only on rank 0)
    if rank == 0:
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= early_stopping_patience:
            early_stop = True  # Early stop flag is set

    # Broadcast the early stop decision to all nodes
    early_stop_tensor = torch.tensor(int(early_stop), dtype=torch.int32)
    dist.broadcast(early_stop_tensor, src=0)

    # Update early_stop flag for all nodes
    early_stop = early_stop_tensor.item() == 1

    # If early stopping, break the training loop
    if early_stop:
        print(f"Early stopping at epoch {epoch} on node {rank}")
        dist.barrier()  # Synchronize all processes before stopping
        break

    # Ensure all nodes sync at the end of the epoch
    dist.barrier()

Key Points:

  • Rank 0 (Master Node): The early stopping logic (like tracking validation loss and deciding when to stop) runs only on the master node (rank == 0).

  • Broadcast Early Stop Signal: The master node (rank == 0) uses dist.broadcast() to share the early stopping decision with all other nodes. The other nodes receive this broadcast and set their early_stop flag accordingly.

  • Synchronization (dist.barrier()): All nodes are synchronized at the end of the epoch or right before early stopping to ensure that no node continues training after the early stop condition is met.

With this approach, you ensure that all nodes stop at the same time when the early stopping condition is met, preventing inconsistent behavior between nodes.

nov05 commented Feb 7, 2025

⚠️🟢 Issue: SMDDP broadcasting error:

 File "train_v1.py", line 449, in main
 dist.broadcast(braodcast_early_stop, src=0)  ## src is the process rank
 File "/opt/conda/lib/python3.9/site-packages/smdistributed/dataparallel/torch/distributed.py", line 156, in wrapper
 return func(*args, **kwargs)
 File "/opt/conda/lib/python3.9/site-packages/smdistributed/dataparallel/__init__.py", line 58, in wrapper
 File "/opt/conda/lib/python3.9/site-packages/smdistributed/dataparallel/torch/distributed.py", line 200, in broadcast
 return torchdst.broadcast(tensor, src=src, group=None, async_op=async_op)
 File "/opt/conda/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 1400, in broadcast
 work = default_pg.broadcast([tensor], opts)
 ValueError: Expected devices to have indices, got cpu
  • ✅ Solution: move early_stop_tensor to the device. The tensor must be explicitly placed on the GPU (e.g. with .to(device)) before it is passed to dist.broadcast().
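
A minimal sketch of that fix (local_rank, early_stop, and the tensor name are assumptions carried over from the snippets above):

import torch
import torch.distributed as dist

device = torch.device("cuda", local_rank)  ## local_rank is assumed to come from the launcher
## Create the flag tensor directly on the GPU (or call .to(device)) before the collective
broadcast_early_stop = torch.tensor(int(early_stop), dtype=torch.int32, device=device)
dist.broadcast(broadcast_early_stop, src=0)  ## src is the process rank
early_stop = bool(broadcast_early_stop.item())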

nov05 commented Feb 8, 2025

⚠️🟢 Issue: another type of broadcasting error.

✅ I can now confirm that the issue was related to how to properly break out of the training loop with the SageMaker-Torch DDP framework, rather than to WebDataset. I simply replaced dist.broadcast(broadcast_early_stop, src=0) with dist.all_reduce(broadcast_early_stop, op=dist.ReduceOp.SUM), and the error is gone.
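
A minimal sketch of the all_reduce pattern (device and early_stop are assumed to be defined as in the earlier snippets; only rank 0 actually sets the flag):

import torch
import torch.distributed as dist

def should_stop(early_stop: bool, device: torch.device) -> bool:
    ## Each rank contributes 0 or 1; a SUM > 0 means at least one rank requested a stop.
    flag = torch.tensor(int(early_stop), dtype=torch.int32, device=device)
    dist.all_reduce(flag, op=dist.ReduceOp.SUM)
    return flag.item() > 0

## In the training loop:
# if should_stop(early_stop, device):
#     break  ## every rank sees the same reduced value, so they all leave the loop together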

[1,mpirank:0,algo-1]<stdout>:👉 VAL: Average loss: 439.5972, Accuracy: 252/1536 (16.41%)
[1,mpirank:0,algo-1]<stdout>:
[1,mpirank:0,algo-1]<stdout>:👉 Train Epoch: 1, Learning Rate: 0.1
[1,mpirank:0,algo-1]<stdout>:
[1,mpirank:1,algo-2]<stderr>:terminate called after throwing an instance of 'SMDDPTimeoutError'
[1,mpirank:1,algo-2]<stderr>:  what():
[1,mpirank:1,algo-2]<stderr>:
[1,mpirank:1,algo-2]<stderr>:#011Timeout: A call to 'broadcast' has taken over 1800.000000 seconds. Terminating the distributed job.It might be one of the workers failed during forward and backward propagation and failed to call "broadcast".
[1,mpirank:1,algo-2]<stderr>:
[1,mpirank:1,algo-2]<stderr>:#011Extend timeout using dist.init_process_group(timeout=timedelta(minutes=60)
[1,mpirank:1,algo-2]<stderr>:#011Extend timeout using dist.init(timeout=timedelta(minutes=60)
[1,mpirank:1,algo-2]<stderr>:#011or refer to the debugging guide. Verify that all ranks call
[1,mpirank:1,algo-2]<stderr>:#011collective operations in the same order and within timeout period.
[1,mpirank:1,algo-2]<stderr>:
[1,mpirank:1,algo-2]<stderr>:[algo-2:00105] *** Process received signal ***
[1,mpirank:1,algo-2]<stderr>:[algo-2:00105] Signal: Aborted (6)
[1,mpirank:1,algo-2]<stderr>:[algo-2:00105] Signal code:  (-6)
UnexpectedStatusException: Error for Training job p5-amazon-bin-job-20250207-172703: Failed. Reason: 
AlgorithmError: SMDDPTimeoutError:
ExitCode 134
ErrorMessage "Exception ignored in: :<function Pipe.__del__ at 0x7f268daebc10>
 Traceback (most recent call last)
 File "/opt/conda/lib/python3.9/site-packages/webdataset/gopen.py", line 121, in __del__
 self.close()
 File "/opt/conda/lib/python3.9/site-packages/webdataset/gopen.py", line 109, in close
 self.wait_for_child()
 File "/opt/conda/lib/python3.9/site-packages/webdataset/gopen.py", line 83, in wait_for_child
 raise IOError(f"{self.args}: exit {self.status} (read) {info}")
 OSError: (('aws s3 cp s3://p5-amazon-bin-images/webdataset/train/train-shard-000004.tar -',), {'shell': True, 
'bufsize': 8192}): exit 1 (read) {}
 Exception ignored in: <function Pipe.__del__ at 0x7f20e0860c10>
 OSError: (('aws s3 cp s3://p5-amazon-bin-images/webdataset/train/train-shard-000005.tar -',), {'shell': True, 
'bufsize': 8192}): exit 1 (read) {}
 Exception ignored in: Traceback (most recent call last)
 <function Pipe.__del__ at 0x7f20e0860c10>  File "/opt/conda/lib/python3.. Check troubleshooting guide for common 
errors: https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-python-sdk-troubleshooting.html
  • Did wait_for_child() hang the process? Likely not.
[1,mpirank:0,algo-1]<stderr>:  File "/opt/conda/lib/python3.9/site-packages/webdataset/gopen.py", line 109, in close
[1,mpirank:0,algo-1]<stderr>:    self.wait_for_child()
