from itertools import chain
from typing import Callable, cast, NamedTuple, Optional, Union

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import _get_device_handle
from torch.distributed.distributed_c10d import _resolve_process_group, ReduceOp
from torch.distributed.tensor import DTensor

from ._fsdp_common import (
    _get_dim0_padded_size,
    _raise_assert_with_print,
    _to_dtype_if_needed,
    compiled_autograd_enabled,
)
from ._fsdp_param import FSDPParam, ShardedState


class AllGatherResult(NamedTuple):
    all_gather_output: torch.Tensor
    all_gather_event: Optional[torch.Event]
    all_gather_work: Optional[dist.distributed_c10d.Work]
    # For each parameter, the dtype of each of its all-gather inputs
    param_all_gather_input_dtypes: list[list[torch.dtype]]
    # For each parameter, the numel of each of its all-gather inputs
    param_all_gather_input_numels: list[list[int]]
    # Flattened input numels, used as the split sizes when copying out of the
    # all-gather output
    all_gather_input_split_sizes: list[int]


def allocate_memory(
    size: int,
    dtype: torch.dtype,
    device: torch.device,
    group: dist.ProcessGroup,
    from_process_group: bool,
) -> torch.Tensor:
    if from_process_group:
        # Prefer the process group backend's allocator when it supports tensor
        # allocation so that collectives can use backend-managed buffers
        backend = group._get_backend(device)
        if backend.supports_tensor_alloc(device):
            return backend.allocate_tensor(size, dtype=dtype, device=device)
    return torch.empty((size,), dtype=dtype, device=device)


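# The copy-in/copy-out helpers below are registered as custom ops in the "fsdp"
# library namespace so that each multi-tensor copy appears as a single op when
# traced (e.g. under torch.compile) rather than as many small copies.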
lib = torch.library.Library("fsdp", "FRAGMENT")

lib.define(
    """
    all_gather_copy_in(
        Tensor[] all_gather_inputs,
        SymInt[] inp_split_sizes,
        SymInt all_gather_input_numel,
        SymInt world_size,
        SymInt rank,
        ScalarType dtype,
        Device device,
        str group_name,
        bool allocate_memory_from_process_group
    ) -> (Tensor, Tensor)
    """
)


@torch.library.impl(lib, "all_gather_copy_in", "Meta")
def all_gather_copy_in_meta(
    all_gather_inputs: list[torch.Tensor],
    inp_split_sizes: list[int],
    all_gather_input_numel: int,
    world_size: int,
    rank: int,
    dtype: torch.dtype,
    device: torch.device,
    group_name: str,
    allocate_memory_from_process_group: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    all_gather_output = torch.empty(
        (all_gather_input_numel * world_size,), dtype=dtype, device="meta"
    )
    all_gather_input = all_gather_output.narrow(
        0, all_gather_input_numel * rank, all_gather_input_numel
    )
    return all_gather_input, all_gather_output


@torch.library.impl(lib, "all_gather_copy_in", "CUDA")
@torch.library.impl(lib, "all_gather_copy_in", "XPU")
@torch.library.impl(lib, "all_gather_copy_in", "HPU")
@torch.library.impl(lib, "all_gather_copy_in", "CPU")
@torch.library.impl(lib, "all_gather_copy_in", "MTIA")
@torch.library.impl(lib, "all_gather_copy_in", "PrivateUse1")
def all_gather_copy_in_cuda(
    all_gather_inputs: list[torch.Tensor],
    inp_split_sizes: list[int],
    all_gather_input_numel: int,
    world_size: int,
    rank: int,
    dtype: torch.dtype,
    device: torch.device,
    group_name: str,
    allocate_memory_from_process_group: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    all_gather_output = allocate_memory(
        all_gather_input_numel * world_size,
        dtype=dtype,
        device=device,
        group=_resolve_process_group(group_name),
        from_process_group=allocate_memory_from_process_group,
    )
    all_gather_input = all_gather_output.narrow(
        0, all_gather_input_numel * rank, all_gather_input_numel
    )
    # Copy each input into this rank's slice of the all-gather output buffer
    foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
    with torch.no_grad():
        torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs)
    return all_gather_input, all_gather_output


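# Wrapper around `torch.split_with_sizes_copy` used for the all-gather copy-out:
# it splits the (world_size, -1) all-gather output along dim 1 and writes each
# split into the corresponding preallocated per-parameter output tensor.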
lib.define(
    "split_with_sizes_copy(Tensor all_gather_output, SymInt[] all_gather_input_split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()"
)


@torch.library.impl(lib, "split_with_sizes_copy", "Meta")
@torch.library.impl(lib, "split_with_sizes_copy", "CUDA")
@torch.library.impl(lib, "split_with_sizes_copy", "XPU")
@torch.library.impl(lib, "split_with_sizes_copy", "HPU")
@torch.library.impl(lib, "split_with_sizes_copy", "CPU")
@torch.library.impl(lib, "split_with_sizes_copy", "MTIA")
@torch.library.impl(lib, "split_with_sizes_copy", "PrivateUse1")
def split_with_sizes_copy(
    all_gather_output: torch.Tensor,
    all_gather_input_split_sizes: list[int],
    dim: int,
    out: list[torch.Tensor],
) -> None:
    torch.split_with_sizes_copy(
        all_gather_output, all_gather_input_split_sizes, dim=dim, out=out
    )


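# Wrapper around `torch._chunk_cat` used for the reduce-scatter copy-in: each
# unsharded gradient is chunked into `num_chunks` (world size) pieces along dim
# 0, and the chunks are concatenated rank-major into the reduce-scatter input.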
lib.define(
    "chunk_cat(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> ()"
)


@torch.library.impl(lib, "chunk_cat", "Meta")
@torch.library.impl(lib, "chunk_cat", "CUDA")
@torch.library.impl(lib, "chunk_cat", "XPU")
@torch.library.impl(lib, "chunk_cat", "HPU")
@torch.library.impl(lib, "chunk_cat", "CPU")
@torch.library.impl(lib, "chunk_cat", "MTIA")
@torch.library.impl(lib, "chunk_cat", "PrivateUse1")
def chunk_cat(
    tensors: list[torch.Tensor],
    dim: int,
    num_chunks: int,
    out: torch.Tensor,
) -> None:
    torch._chunk_cat(tensors, dim, num_chunks, out=out)


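# All-gather flow: the copy-in (flattening each parameter's all-gather inputs
# into one contiguous buffer) runs in `all_gather_copy_in_stream`, and the
# collective itself runs in `all_gather_stream`, so both can overlap with work
# on the current stream. A rough sketch of the expected call order (the caller
# is the FSDP param-group logic):
#
#   result = foreach_all_gather(fsdp_params, group, async_op,
#                               all_gather_copy_in_stream, all_gather_stream,
#                               device)
#   ...  # later, once the unsharded parameters are needed
#   foreach_all_gather_copy_out(result, fsdp_params, group)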
@torch.no_grad()
def foreach_all_gather(
    fsdp_params: list[FSDPParam],
    group: dist.ProcessGroup,
    async_op: bool,
    all_gather_copy_in_stream: torch.Stream,
    all_gather_stream: torch.Stream,
    device: torch.device,
    allocate_memory_from_process_group: bool = False,
) -> Optional[AllGatherResult]:
    world_size, rank = group.size(), group.rank()
    device_handle = _get_device_handle(device.type)
    with device_handle.stream(all_gather_copy_in_stream):
        param_all_gather_inputs = _get_param_all_gather_inputs(fsdp_params)
        (
            param_all_gather_input_dtypes,
            param_all_gather_input_numels,
            dtype,
        ) = _get_all_gather_input_metadatas(param_all_gather_inputs)
        if dtype == torch.uint8:
            # Mixed input dtypes: all-gather the raw bytes
            all_gather_inputs = [
                t.view(torch.uint8) for ts in param_all_gather_inputs for t in ts
            ]
        else:
            all_gather_inputs = [*chain.from_iterable(param_all_gather_inputs)]
        inp_split_sizes = [t.numel() for t in all_gather_inputs]
        all_gather_input_numel = sum(inp_split_sizes)
        all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(
            all_gather_inputs,
            inp_split_sizes,
            all_gather_input_numel,
            world_size,
            rank,
            dtype,
            device,
            group.group_name,
            allocate_memory_from_process_group,
        )
    del param_all_gather_inputs
    all_gather_stream.wait_stream(all_gather_copy_in_stream)
    with device_handle.stream(all_gather_stream):
        all_gather_work = dist.all_gather_into_tensor(
            output_tensor=all_gather_output,
            input_tensor=all_gather_input,
            group=group,
            async_op=async_op,
        )
        all_gather_event = all_gather_stream.record_event()
        return AllGatherResult(
            all_gather_output,
            all_gather_event,
            all_gather_work,
            param_all_gather_input_dtypes,
            param_all_gather_input_numels,
            inp_split_sizes,
        )


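# Collects, for each FSDPParam, the list of tensors to contribute to the
# all-gather. Parameters that only need a dtype cast (no CPU offload and no
# `fsdp_pre_all_gather` extension on the local tensor) are batched into a
# single foreach copy; all others fall back to `FSDPParam.all_gather_inputs`.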
@torch.no_grad()
def _get_param_all_gather_inputs(
    fsdp_params: list[FSDPParam],
) -> list[list[torch.Tensor]]:
    if compiled_autograd_enabled():
        # The foreach-copy fast path below is not used under compiled autograd
        return [fsdp_param.all_gather_inputs for fsdp_param in fsdp_params]

    # Fast path: batch the dtype casts for parameters without extensions into a
    # single foreach copy
    def use_foreach_copy(fsdp_param: FSDPParam) -> bool:
        return (
            fsdp_param.param_dtype is not None
            and not fsdp_param.offload_to_cpu
            and not hasattr(fsdp_param._sharded_local_tensor, "fsdp_pre_all_gather")
        )

    param_all_gather_inputs: list[list[torch.Tensor]] = [[] for _ in fsdp_params]
    foreach_copy_indices: list[int] = []
    foreach_copy_inputs: list[torch.Tensor] = []
    foreach_copy_input_numels: list[int] = []

    # 1st pass: collect the copy inputs for the foreach-copy parameters and
    # compute the all-gather inputs directly for the rest
    for i, fsdp_param in enumerate(fsdp_params):
        if use_foreach_copy(fsdp_param):
            foreach_copy_indices.append(i)
            all_gather_input = (
                fsdp_param._sharded_param_data
                if fsdp_param.sharded_state == ShardedState.SHARDED
                else cast(torch.Tensor, fsdp_param._sharded_post_forward_param_data)
            )
            foreach_copy_inputs.append(all_gather_input)
            foreach_copy_input_numels.append(all_gather_input.numel())
        else:
            param_all_gather_inputs[i] = fsdp_param.all_gather_inputs

    # 2nd pass: copy the collected inputs into one flat buffer in the target
    # dtype and view the result back per parameter
    if foreach_copy_inputs:
        fsdp_param_0 = fsdp_params[foreach_copy_indices[0]]
        param_dtype, device = fsdp_param_0.param_dtype, fsdp_param_0.device
        flat_foreach_copy_input = torch.empty(
            (sum(foreach_copy_input_numels),), device=device, dtype=param_dtype
        )
        splits = torch.split(flat_foreach_copy_input, foreach_copy_input_numels)
        torch._foreach_copy_(splits, foreach_copy_inputs)
        for i, split in zip(foreach_copy_indices, splits):
            param_all_gather_inputs[i] = [split]

    return param_all_gather_inputs


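# Copy-out flow: waits on the all-gather (event for the sync case, work object
# for the async case), allocates/initializes each parameter's all-gather
# outputs, and splits the (world_size, -1) all-gather output into them. For
# parameters sharded on a nonzero dim, the splits are first written to
# temporaries and then re-laid-out by concatenating per-rank chunks along the
# shard dim.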
@torch.no_grad()
def foreach_all_gather_copy_out(
    all_gather_result: AllGatherResult,
    fsdp_params: list[FSDPParam],
    group: dist.ProcessGroup,
) -> None:
    (
        all_gather_output,
        all_gather_event,
        all_gather_work,
        param_all_gather_input_dtypes,
        param_all_gather_input_numels,
        all_gather_input_split_sizes,
    ) = all_gather_result
    device = all_gather_output.device
    device_handle = _get_device_handle(device.type)
    if all_gather_event is not None:
        device_handle.current_stream().wait_event(all_gather_event)
    if isinstance(all_gather_work, dist.distributed_c10d.Work):
        all_gather_work.wait()
    world_size = group.size()

    split_with_sizes_out: list[torch.Tensor] = []
    shard_i_copy_infos: list[tuple[FSDPParam, list[torch.Tensor]]] = []
    for all_gather_input_numels, all_gather_input_dtypes, fsdp_param in zip(
        param_all_gather_input_numels, param_all_gather_input_dtypes, fsdp_params
    ):
        # Under compiled autograd, always recreate the all-gather outputs
        force_recreate = compiled_autograd_enabled()
        fsdp_param.init_all_gather_outputs(
            all_gather_input_numels,
            all_gather_input_dtypes,
            world_size,
            device,
            force_recreate=force_recreate,
        )
        if not force_recreate:
            fsdp_param.alloc_all_gather_outputs()
        param_all_gather_outputs = fsdp_param.all_gather_outputs
        if fsdp_param.fsdp_placement.dim != 0:
            # For Shard(i) with i > 0, copy out into temporaries first and then
            # chunk-cat into the final all-gather outputs below
            param_all_gather_outputs = [
                torch.empty_like(t) for t in param_all_gather_outputs
            ]
            shard_i_copy_infos.append((fsdp_param, param_all_gather_outputs))
        split_with_sizes_out.extend(param_all_gather_outputs)

    all_gather_output = all_gather_output.view(world_size, -1)
    if all_gather_output.dtype == torch.uint8:
        out = [t.view(world_size, -1).view(torch.uint8) for t in split_with_sizes_out]
    else:
        out = [t.view(world_size, -1) for t in split_with_sizes_out]

    if torch._dynamo.is_compiling():
        # Skip the version-counter preservation when compiling
        non_inference_outs = []
    else:
        non_inference_outs = [o for o in out if not o.is_inference()]

    if len(non_inference_outs) > 0:
        # Copy out without bumping the version counters of the autograd-tracked
        # output tensors
        with torch.autograd._unsafe_preserve_version_counter(tuple(non_inference_outs)):
            torch.ops.fsdp.split_with_sizes_copy(
                all_gather_output, all_gather_input_split_sizes, dim=1, out=out
            )
    else:
        torch.ops.fsdp.split_with_sizes_copy(
            all_gather_output, all_gather_input_split_sizes, dim=1, out=out
        )

    for fsdp_param, param_all_gather_outputs in shard_i_copy_infos:
        # Chunk-cat from the temporaries into the final all-gather outputs
        shard_dim = fsdp_param.fsdp_placement.dim
        with torch.autograd._unsafe_preserve_version_counter(
            tuple(fsdp_param.all_gather_outputs)
        ):
            for param_all_gather_output, target_all_gather_output in zip(
                param_all_gather_outputs, fsdp_param.all_gather_outputs
            ):
                padded_sharded_size = (
                    fsdp_param.padded_sharded_param_size
                    if fsdp_param.sharded_state == ShardedState.SHARDED
                    else cast(
                        torch.Tensor, fsdp_param._sharded_post_forward_param_data
                    ).size()
                )
                pre_param_size = list(padded_sharded_size)
                pre_param_size[0] *= world_size
                chunks = torch.chunk(
                    param_all_gather_output.view(pre_param_size), world_size, dim=0
                )
                post_param_size = list(padded_sharded_size)
                post_param_size[shard_dim] *= world_size
                cat_out = target_all_gather_output.view(post_param_size)
                torch.cat(chunks, dim=shard_dim, out=cat_out)


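# Reduce flow: the unsharded gradients are copied into a single reduce-scatter
# input buffer (via `chunk_cat`), reduce-scattered in `reduce_scatter_stream`,
# optionally all-reduced across `all_reduce_group` (HSDP) in
# `all_reduce_stream`, optionally passed to `all_reduce_hook`, and finally
# viewed back into per-parameter sharded gradients (with optional CPU offload
# and gradient accumulation).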
@torch.no_grad()
def foreach_reduce(
    fsdp_params: list[FSDPParam],
    unsharded_grads: list[torch.Tensor],
    reduce_scatter_group: dist.ProcessGroup,
    reduce_scatter_stream: torch.Stream,
    orig_dtype: Optional[torch.dtype],
    reduce_dtype: Optional[torch.dtype],
    device: torch.device,
    gradient_divide_factor: Optional[float],
    all_reduce_group: Optional[dist.ProcessGroup],
    all_reduce_stream: torch.Stream,
    all_reduce_grads: bool,
    partial_reduce_output: Optional[torch.Tensor],
    all_reduce_hook: Optional[Callable[[torch.Tensor], None]],
    allocate_memory_from_process_group: bool = False,
    force_sum_reduction_for_comms: bool = False,
) -> tuple[
    torch.Tensor,
    torch.Event,
    torch.Event,
    Optional[torch.Tensor],
    Optional[torch.Event],
    Optional[torch.Tensor],
]:
    """
    ``unsharded_grads`` owns the references to the gradients computed by
    autograd, so clearing the list frees the gradients.
    """
    grad_dtypes = {grad.dtype for grad in unsharded_grads}
    if len(grad_dtypes) != 1:
        # All gradients in one reduce must share a dtype since they are
        # flattened into a single buffer
        _raise_assert_with_print(
            f"FSDP reduce-scatter expects uniform gradient dtype but got {grad_dtypes}"
        )
    grad_dtype = unsharded_grads[0].dtype
    reduce_dtype = reduce_dtype or grad_dtype
    (predivide_factor, postdivide_factor, reduce_scatter_op, all_reduce_op) = (
        _get_gradient_divide_factors(
            reduce_scatter_group,
            all_reduce_group,
            reduce_dtype,
            device.type,
            gradient_divide_factor,
            force_sum_reduction_for_comms,
        )
    )
    world_size = reduce_scatter_group.size()
    for i, (fsdp_param, unsharded_grad) in enumerate(zip(fsdp_params, unsharded_grads)):
        if (shard_dim := fsdp_param.fsdp_placement.dim) == 0:
            continue
        # For Shard(i) with i > 0, re-layout the gradient so that each rank's
        # shard is contiguous along dim 0 before the flat reduce-scatter
        assert unsharded_grad.size(shard_dim) % world_size == 0, (
            f"Shard({shard_dim}) requires even sharding: {unsharded_grad.size()=} {world_size=}"
        )
        chunks = torch.chunk(unsharded_grad, world_size, dim=shard_dim)
        unsharded_grads[i] = torch.cat(chunks, dim=0)
    padded_unsharded_sizes = tuple(
        _get_dim0_padded_size(grad.size(), world_size) for grad in unsharded_grads
    )
    reduce_scatter_input_numel = sum(s.numel() for s in padded_unsharded_sizes)
    reduce_scatter_output_numel = reduce_scatter_input_numel // world_size
    reduce_scatter_input = allocate_memory(
        reduce_scatter_input_numel,
        dtype=reduce_dtype,
        device=device,
        group=reduce_scatter_group,
        from_process_group=allocate_memory_from_process_group,
    )
    device_handle = _get_device_handle(device.type)
    foreach_reduce_scatter_copy_in(unsharded_grads, reduce_scatter_input, world_size)
    current_stream = device_handle.current_stream()
    # Only after the copy-in has been issued can we drop the gradient
    # references; this list owns the last references to them
    unsharded_grads.clear()
    reduce_scatter_stream.wait_stream(current_stream)
    all_reduce_input = None
    all_reduce_event = None
    with device_handle.stream(reduce_scatter_stream):
        reduce_output = allocate_memory(
            reduce_scatter_output_numel,
            dtype=reduce_dtype,
            device=device,
            group=reduce_scatter_group,
            from_process_group=allocate_memory_from_process_group,
        )
        _div_if_needed(reduce_scatter_input, predivide_factor)
        dist.reduce_scatter_tensor(
            output=reduce_output,
            input=reduce_scatter_input,
            group=reduce_scatter_group,
            op=reduce_scatter_op,
        )
        reduce_scatter_event = reduce_scatter_stream.record_event()
        post_reduce_stream = reduce_scatter_stream
        if all_reduce_group is not None:  # HSDP
            # If not all-reducing this microbatch's gradients yet, accumulate
            # the partial reduce-scatter output and return early
            if not all_reduce_grads:
                if partial_reduce_output is not None:
                    partial_reduce_output += reduce_output
                else:
                    partial_reduce_output = reduce_output
                return (
                    reduce_scatter_input,
                    reduce_scatter_event,
                    post_reduce_stream.record_event(),
                    all_reduce_input,
                    all_reduce_event,
                    partial_reduce_output,
                )
            if partial_reduce_output is not None:
                reduce_output += partial_reduce_output
            post_reduce_stream = all_reduce_stream
            all_reduce_stream.wait_stream(reduce_scatter_stream)
            with device_handle.stream(all_reduce_stream):
                dist.all_reduce(
                    reduce_output,
                    group=all_reduce_group,
                    op=all_reduce_op,
                )
                all_reduce_input = reduce_output
                all_reduce_event = all_reduce_stream.record_event()

        if all_reduce_hook is not None:
            # Run the user hook on the (possibly all-reduced) reduce output in
            # `all_reduce_stream` after it has waited on the reduce-scatter
            post_reduce_stream = all_reduce_stream
            all_reduce_stream.wait_stream(reduce_scatter_stream)
            with device_handle.stream(all_reduce_stream):
                all_reduce_hook(reduce_output)
    with device_handle.stream(post_reduce_stream):
        _div_if_needed(reduce_output, postdivide_factor)
        reduce_output = _to_dtype_if_needed(reduce_output, orig_dtype)
        # View out and accumulate/assign the sharded gradients
        flat_grad_offset = 0
        for padded_unsharded_size, fsdp_param in zip(
            padded_unsharded_sizes, fsdp_params
        ):
            # View this parameter's slice of the reduce output as its sharded
            # gradient
            new_sharded_grad = torch.as_strided(
                reduce_output,
                size=fsdp_param.sharded_size,
                stride=fsdp_param.contiguous_sharded_stride,
                storage_offset=flat_grad_offset,
            )
            to_accumulate_grad = fsdp_param.sharded_param.grad is not None
            if fsdp_param.offload_to_cpu:
                # Only overlap the D2H copy (via pinned memory) when not
                # accumulating, since accumulation needs the copy result
                non_blocking = fsdp_param.pin_memory and not to_accumulate_grad
                new_sharded_grad = new_sharded_grad.to(
                    torch.device("cpu"), non_blocking=non_blocking
                )
                if non_blocking:
                    # Record an event so the CPU can later wait for the D2H
                    # copy to finish before using the gradient
                    fsdp_param.grad_offload_event = reduce_scatter_stream.record_event()
            if to_accumulate_grad:
                assert isinstance(fsdp_param.sharded_param.grad, DTensor)
                fsdp_param.sharded_param.grad._local_tensor += new_sharded_grad
            else:
                new_sharded_dtensor_grad = fsdp_param.to_sharded_dtensor(
                    new_sharded_grad
                )
                fsdp_param.sharded_param.grad = new_sharded_dtensor_grad
            if not compiled_autograd_enabled():
                for hook in (
                    getattr(fsdp_param.sharded_param, "_post_accumulate_grad_hooks", {})
                    or {}
                ).values():
                    hook(fsdp_param.sharded_param)
            padded_sharded_numel = padded_unsharded_size.numel() // world_size
            flat_grad_offset += padded_sharded_numel
        post_reduce_event = post_reduce_stream.record_event()
    # Return the reduce-scatter/all-reduce inputs and events so the caller can
    # keep them alive until the collectives complete, preventing their memory
    # from being reused too early
    return (
        reduce_scatter_input,
        reduce_scatter_event,
        post_reduce_event,
        all_reduce_input,
        all_reduce_event,
        None,
    )


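# Copies the unsharded gradients into the flat reduce-scatter input: the input
# is viewed as (world_size, -1) so that `chunk_cat` lays out each gradient's
# per-rank chunks in rank-major order, matching `reduce_scatter_tensor`'s
# expected layout.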
def foreach_reduce_scatter_copy_in(
    unsharded_grads: list[torch.Tensor],
    reduce_scatter_input: torch.Tensor,
    world_size: int,
) -> None:
    reduce_scatter_input = reduce_scatter_input.view(world_size, -1)
    torch.ops.fsdp.chunk_cat(
        unsharded_grads, dim=0, num_chunks=world_size, out=reduce_scatter_input
    )


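# Returns the per-parameter input dtypes and numels plus the dtype to use for
# the flat all-gather buffer. If the inputs do not all share one dtype, the
# buffer dtype falls back to uint8 and the inputs are gathered as raw bytes.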
def _get_all_gather_input_metadatas(
    param_all_gather_inputs: list[list[torch.Tensor]],
) -> tuple[list[list[torch.dtype]], list[list[int]], torch.dtype]:
    param_all_gather_input_dtypes: list[list[torch.dtype]] = []
    param_all_gather_input_numels: list[list[int]] = []
    all_gather_dtype = param_all_gather_inputs[0][0].dtype
    for all_gather_inputs in param_all_gather_inputs:
        input_dtypes: list[torch.dtype] = []
        input_numels: list[int] = []
        for all_gather_input in all_gather_inputs:
            if all_gather_input.dtype != all_gather_dtype:
                all_gather_dtype = torch.uint8
            input_dtypes.append(all_gather_input.dtype)
            input_numels.append(all_gather_input.numel())
        param_all_gather_input_dtypes.append(input_dtypes)
        param_all_gather_input_numels.append(input_numels)
    return (
        param_all_gather_input_dtypes,
        param_all_gather_input_numels,
        all_gather_dtype,
    )


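# Decides how to average gradients across the data-parallel ranks. When the
# reduce dtype is safe from overflow (fp32/bf16) and sum-only reductions are
# not forced, the division is folded into the collective (AVG or a
# pre-multiplied sum). Otherwise the factor is split into a pre-division and a
# post-division around the reduction, roughly at the square root of the factor:
# e.g. for a factor of 64, the gradients are divided by 8 before and by 8 after
# the reduce-scatter.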
def _get_gradient_divide_factors(
    reduce_scatter_group: dist.ProcessGroup,
    all_reduce_group: Optional[dist.ProcessGroup],
    reduce_dtype: torch.dtype,
    device_type: str = "",
    factor: Optional[float] = None,
    force_sum_reduction_for_comms: bool = False,
) -> tuple[
    Optional[float],
    Optional[float],
    Union[dist.ReduceOp, dist.ReduceOp.RedOpType],
    Union[dist.ReduceOp, dist.ReduceOp.RedOpType],
]:
    if device_type == "mtia":
        # Force plain SUM reductions on MTIA
        force_sum_reduction_for_comms = True

    # For fp32 and bf16 there is no overflow concern, so the division can be
    # folded into the collective itself; other reduce dtypes (e.g. fp16) divide
    # before and after the reduction to stay in range
    overflow_risk = reduce_dtype not in (torch.float32, torch.bfloat16)

    data_parallel_size = reduce_scatter_group.size()
    if all_reduce_group is not None:
        data_parallel_size *= all_reduce_group.size()

    if factor is None:
        factor = float(data_parallel_size)

    if not overflow_risk and not force_sum_reduction_for_comms:
        if factor == data_parallel_size:
            # Use the built-in averaging reduce op
            return None, None, ReduceOp.AVG, ReduceOp.AVG
        else:
            # Fold the custom division factor into the reduce-scatter as a
            # pre-multiplied sum
            reduce_scatter_op = torch.distributed._make_nccl_premul_sum(1 / factor)
            return None, None, reduce_scatter_op, ReduceOp.SUM

    pre_factor: Optional[float]
    if overflow_risk:
        # Split the division into a pre-factor and a post-factor (each roughly
        # the square root of `factor`) to reduce the risk of overflow or
        # underflow in the lower-precision reduce dtype
        pre_factor = 1
        while factor % pre_factor == 0 and factor / pre_factor > pre_factor:
            pre_factor *= 2
        post_factor = factor / pre_factor
    else:
        # No overflow risk: divide only once, after the reduction
        pre_factor, post_factor = None, factor

    return pre_factor, post_factor, ReduceOp.SUM, ReduceOp.SUM


def _div_if_needed(tensor: torch.Tensor, div_factor: Optional[float]) -> None:
    if div_factor is not None and div_factor != 1:
        tensor.div_(div_factor)