Skip to content

Instantly share code, notes, and snippets.

@ptrblck
Created April 8, 2020 10:27
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from apex.parallel import SyncBatchNorm as ApexSyncBatchNorm
import argparse
# Repro script: compares a SyncBatchNorm model wrapped in DDP (native
# ``nn.SyncBatchNorm`` or, with ``--apex``, apex's ``SyncBatchNorm``) against
# a plain per-process ``nn.BatchNorm2d`` reference that shares the same conv
# weights. Every process prints input/output summaries, the BN running
# statistics and the conv weight gradient magnitude for both models.


def _print_summary(label, tensor):
    """Print sum, mean, std, min, max and device of ``tensor``.

    Args:
        label: Text placed directly in front of ``.sum()`` in the output line.
        tensor: Tensor whose summary statistics are printed.
    """
    print(
        f'{label}.sum() {tensor.sum()}, .mean() {tensor.mean()}, '
        f'.std() {tensor.std()}, .min() {tensor.min()}, '
        f'.max() {tensor.max()}, device {tensor.device}'
    )


parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument('--apex', action='store_true')
args = parser.parse_args()

# Every process uses the same seed.
torch.manual_seed(2809)

# Setup DDP
torch.cuda.set_device(args.local_rank)
device = torch.device(f'cuda:{args.local_rank}')
# With the env:// init method the global rank and world size are taken from
# the RANK / WORLD_SIZE environment variables provided by the launcher.
# The node-local rank is deliberately NOT passed as ``rank``: it would
# override RANK and produce duplicate ranks when running on several nodes.
torch.distributed.init_process_group('nccl', init_method='env://')

# Setup model: identical conv layer, only the SyncBatchNorm implementation
# differs between the two variants.
sync_bn_cls = ApexSyncBatchNorm if args.apex else nn.SyncBatchNorm
model = nn.Sequential(
    nn.Conv2d(3, 6, 3, 1, 1),
    sync_bn_cls(6),
)

# Setup reference model (plain BatchNorm2d) and give its conv layer the same
# parameters as the model under test.
model_reference = nn.Sequential(
    nn.Conv2d(3, 6, 3, 1, 1),
    nn.BatchNorm2d(6),
)
with torch.no_grad():
    model_reference[0].weight.copy_(model[0].weight)
    model_reference[0].bias.copy_(model[0].bias)
model_reference.to(device)

# Setup SyncBN
# nn.SyncBatchNorm.convert_sync_batchnorm is not used: the sync layer is
# instantiated directly above.
model = model.to(device)
model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)

# Create random data: local rank 0 gets a larger batch that is scaled by 100,
# all other ranks get a smaller, unscaled batch.
if args.local_rank == 0:
    data = torch.randn(16, 3, 24, 24, device=device) * 100
else:
    data = torch.randn(8, 3, 24, 24, device=device)
_print_summary('Input', data)

# DDP forward/backward
output = model(data)
_print_summary('DDP output', output)
output.sum().backward()

# Reference forward/backward
output_reference = model_reference(data)
_print_summary('Reference output', output_reference)
output_reference.sum().backward()

# Print stats
ddp_bn = model.module[1]
reference_bn = model_reference[1]
print('DDP stats ', ddp_bn.running_mean, ddp_bn.running_var)
print('Reference stats ', reference_bn.running_mean, reference_bn.running_var)
print('DDP grads ', model.module[0].weight.grad.abs().sum())
print('Reference grads ', model_reference[0].weight.grad.abs().sum())
@ptrblck
Copy link
Author

ptrblck commented Apr 8, 2020

Vanilla output

root@cd8d754a6069:/workspace/src# python -m torch.distributed.launch --nproc_per_node=4 repro.py
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
Input.sum() 22933.48828125, .mean() 0.8294809460639954, .std() 100.45414733886719, .min() -408.5179443359375, .max() 426.8013000488281, device cuda:0
Input.sum() -9.270572662353516, .mean() -0.0006706143612973392, .std() 0.992488443851471, .min() -3.8868343830108643, .max() 3.5236613750457764, device cuda:1
Input.sum() -9.270572662353516, .mean() -0.0006706143612973392, .std() 0.992488443851471, .min() -3.8868343830108643, .max() 3.5236613750457764, device cuda:2
Input.sum() -9.270572662353516, .mean() -0.0006706143612973392, .std() 0.992488443851471, .min() -3.8868343830108643, .max() 3.5236613750457764, device cuda:3
DDP output.sum() -51.18982696533203, .mean() -0.0009257419733330607, .std() 1.581018328666687, .min() -7.145139694213867, .max() 7.250021934509277, device cuda:0
DDP output.sum() 17.063417434692383, .mean() 0.0006171664572320879, .std() 0.016831880435347557, .min() -0.06054949387907982, .max() 0.0813681036233902, device cuda:1
DDP output.sum() 17.063417434692383, .mean() 0.0006171664572320879, .std() 0.016831880435347557, .min() -0.06054949387907982, .max() 0.0813681036233902, device cuda:2
DDP output.sum() 17.063417434692383, .mean() 0.0006171664572320879, .std() 0.016831880435347557, .min() -0.06054949387907982, .max() 0.0813681036233902, device cuda:3
Reference output.sum() 5.340576171875e-05, .mean() 9.65815982745255e-10, .std() 1.0000090599060059, .min() -4.512101173400879, .max() 4.594338417053223, device cuda:0
Reference output.sum() -0.000179290771484375, .mean() -6.4847647252008755e-09, .std() 1.0000003576278687, .min() -3.7496793270111084, .max() 4.66790246963501, device cuda:1
Reference output.sum() -0.000179290771484375, .mean() -6.4847647252008755e-09, .std() 1.0000003576278687, .min() -3.7496793270111084, .max() 4.66790246963501, device cuda:2
Reference output.sum() -0.000179290771484375, .mean() -6.4847647252008755e-09, .std() 1.0000003576278687, .min() -3.7496793270111084, .max() 4.66790246963501, device cuda:3
DDP stats  tensor([ 0.0255, -0.0082, -0.0093, -0.0104, -0.0158, -0.0472], device='cuda:0') tensor([111.1794,  87.8819, 112.0774, 130.4485, 135.5200, 140.0746],
       device='cuda:0')
DDP stats  tensor([ 0.0255, -0.0082, -0.0093, -0.0104, -0.0158, -0.0472], device='cuda:1') tensor([111.1794,  87.8819, 112.0774, 130.4485, 135.5200, 140.0746],
       device='cuda:1')
DDP stats  tensor([ 0.0255, -0.0082, -0.0093, -0.0104, -0.0158, -0.0472], device='cuda:2') tensor([111.1794,  87.8819, 112.0774, 130.4485, 135.5200, 140.0746],
       device='cuda:2')
DDP stats  tensor([ 0.0255, -0.0082, -0.0093, -0.0104, -0.0158, -0.0472], device='cuda:3') tensor([111.1794,  87.8819, 112.0774, 130.4485, 135.5200, 140.0746],
       device='cuda:3')
Reference stats  tensor([ 0.0710,  0.0012,  0.0002, -0.0112, -0.0646, -0.0910], device='cuda:0') tensor([276.5405, 218.3357, 278.8197, 324.7449, 337.3835, 348.7750],
       device='cuda:0')
Reference stats  tensor([-0.0049, -0.0144, -0.0156, -0.0099,  0.0168, -0.0180], device='cuda:1') tensor([0.9276, 0.9211, 0.9270, 0.9317, 0.9325, 0.9348], device='cuda:1')
Reference stats  tensor([-0.0049, -0.0144, -0.0156, -0.0099,  0.0168, -0.0180], device='cuda:2') tensor([0.9276, 0.9211, 0.9270, 0.9317, 0.9325, 0.9348], device='cuda:2')
Reference stats  tensor([-0.0049, -0.0144, -0.0156, -0.0099,  0.0168, -0.0180], device='cuda:3') tensor([0.9276, 0.9211, 0.9270, 0.9317, 0.9325, 0.9348], device='cuda:3')
DDP grads  tensor(512.7340, device='cuda:0')
DDP grads  tensor(512.7340, device='cuda:1')
DDP grads  tensor(512.7340, device='cuda:2')
DDP grads  tensor(512.7340, device='cuda:3')
Reference grads  tensor(0.0002, device='cuda:0')
Reference grads  tensor(0.0008, device='cuda:1')
Reference grads  tensor(0.0008, device='cuda:2')
Reference grads  tensor(0.0008, device='cuda:3')

Apex output

root@cd8d754a6069:/workspace/src# python -m torch.distributed.launch --nproc_per_node=4 repro.py --apex
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
Input.sum() 22933.48828125, .mean() 0.8294809460639954, .std() 100.45414733886719, .min() -408.5179443359375, .max() 426.8013000488281, device cuda:0
Input.sum() -9.270572662353516, .mean() -0.0006706143612973392, .std() 0.992488443851471, .min() -3.8868343830108643, .max() 3.5236613750457764, device cuda:1
Input.sum() -9.270572662353516, .mean() -0.0006706143612973392, .std() 0.992488443851471, .min() -3.8868343830108643, .max() 3.5236613750457764, device cuda:2
Input.sum() -9.270572662353516, .mean() -0.0006706143612973392, .std() 0.992488443851471, .min() -3.8868343830108643, .max() 3.5236613750457764, device cuda:3
DDP output.sum() -80.93080139160156, .mean() -0.0014635922852903605, .std() 1.99970543384552, .min() -9.040897369384766, .max() 9.16562557220459, device cuda:0
DDP output.sum() 13.488600730895996, .mean() 0.0004878689651377499, .std() 0.02039550431072712, .min() -0.07519855350255966, .max() 0.09870573878288269, device cuda:1
DDP output.sum() 13.488600730895996, .mean() 0.0004878689651377499, .std() 0.02039550431072712, .min() -0.07519855350255966, .max() 0.09870573878288269, device cuda:2
DDP output.sum() 13.488600730895996, .mean() 0.0004878689651377499, .std() 0.02039550431072712, .min() -0.07519855350255966, .max() 0.09870573878288269, device cuda:3
Reference output.sum() 5.340576171875e-05, .mean() 9.65815982745255e-10, .std() 1.0000090599060059, .min() -4.512101173400879, .max() 4.594338417053223, device cuda:0
Reference output.sum() -0.000179290771484375, .mean() -6.4847647252008755e-09, .std() 1.0000003576278687, .min() -3.7496793270111084, .max() 4.66790246963501, device cuda:1
Reference output.sum() -0.000179290771484375, .mean() -6.4847647252008755e-09, .std() 1.0000003576278687, .min() -3.7496793270111084, .max() 4.66790246963501, device cuda:2
Reference output.sum() -0.000179290771484375, .mean() -6.4847647252008755e-09, .std() 1.0000003576278687, .min() -3.7496793270111084, .max() 4.66790246963501, device cuda:3
DDP stats  tensor([ 0.0141, -0.0105, -0.0117, -0.0102, -0.0036, -0.0363], device='cuda:0') tensor([69.8360, 55.2708, 70.3950, 81.8784, 85.0508, 87.8978], device='cuda:0')
DDP stats  tensor([ 0.0141, -0.0105, -0.0117, -0.0102, -0.0036, -0.0363], device='cuda:1') tensor([69.8378, 55.2722, 70.3969, 81.8806, 85.0531, 87.9001], device='cuda:1')
DDP stats  tensor([ 0.0141, -0.0105, -0.0117, -0.0102, -0.0036, -0.0363], device='cuda:2') tensor([69.8378, 55.2722, 70.3969, 81.8806, 85.0531, 87.9001], device='cuda:2')
DDP stats  tensor([ 0.0141, -0.0105, -0.0117, -0.0102, -0.0036, -0.0363], device='cuda:3') tensor([69.8378, 55.2722, 70.3969, 81.8806, 85.0531, 87.9001], device='cuda:3')
Reference stats  tensor([ 0.0710,  0.0012,  0.0002, -0.0112, -0.0646, -0.0910], device='cuda:0') tensor([276.5405, 218.3357, 278.8197, 324.7449, 337.3835, 348.7750],
       device='cuda:0')
Reference stats  tensor([-0.0049, -0.0144, -0.0156, -0.0099,  0.0168, -0.0180], device='cuda:1') tensor([0.9276, 0.9211, 0.9270, 0.9317, 0.9325, 0.9348], device='cuda:1')
Reference stats  tensor([-0.0049, -0.0144, -0.0156, -0.0099,  0.0168, -0.0180], device='cuda:2') tensor([0.9276, 0.9211, 0.9270, 0.9317, 0.9325, 0.9348], device='cuda:2')
Reference stats  tensor([-0.0049, -0.0144, -0.0156, -0.0099,  0.0168, -0.0180], device='cuda:3') tensor([0.9276, 0.9211, 0.9270, 0.9317, 0.9325, 0.9348], device='cuda:3')
DDP grads  tensor(0.0019, device='cuda:0')
DDP grads  tensor(0.0019, device='cuda:1')
DDP grads  tensor(0.0019, device='cuda:2')
DDP grads  tensor(0.0019, device='cuda:3')
Reference grads  tensor(0.0002, device='cuda:0')
Reference grads  tensor(0.0008, device='cuda:1')
Reference grads  tensor(0.0008, device='cuda:2')
Reference grads  tensor(0.0008, device='cuda:3')

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment