# _fsdp_collectives.py
from itertools import chain
from typing import Callable, cast, NamedTuple, Optional, Union
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import _get_device_handle
from torch.distributed.distributed_c10d import _resolve_process_group, ReduceOp
from torch.distributed.tensor import DTensor
from ._fsdp_common import (
_get_dim0_padded_size,
_raise_assert_with_print,
_to_dtype_if_needed,
compiled_autograd_enabled,
)
from ._fsdp_param import FSDPParam, ShardedState
class AllGatherResult(NamedTuple):
all_gather_output: torch.Tensor
all_gather_event: Optional[torch.Event]
all_gather_work: Optional[dist.distributed_c10d.Work]
# For each parameter, the all-gather input dtype for each input
param_all_gather_input_dtypes: list[list[torch.dtype]]
# For each parameter, the all-gather input numel for each input
param_all_gather_input_numels: list[list[int]]
# 1D flattened version of `param_all_gather_input_numels` saved to avoid
# CPU overhead from recomputing
all_gather_input_split_sizes: list[int]
def allocate_memory(
size: int,
dtype: torch.dtype,
device: torch.device,
group: dist.ProcessGroup,
from_process_group: bool,
) -> torch.Tensor:
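    """
    Allocate a flat buffer of ``size`` elements. If requested and the process
    group's backend supports tensor allocation, allocate through the backend;
    otherwise fall back to ``torch.empty``.
    """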
if from_process_group:
backend = group._get_backend(device)
if backend.supports_tensor_alloc(device):
return backend.allocate_tensor(size, dtype=dtype, device=device)
return torch.empty((size,), dtype=dtype, device=device)
lib = torch.library.Library("fsdp", "FRAGMENT") # noqa: TOR901
lib.define(
"""
all_gather_copy_in(
Tensor[] all_gather_inputs,
SymInt[] inp_split_sizes,
SymInt all_gather_input_numel,
SymInt world_size,
SymInt rank,
ScalarType dtype,
Device device,
str group_name,
bool allocate_memory_from_process_group
) -> (Tensor, Tensor)
"""
)
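# `all_gather_copy_in` allocates the flat all-gather output buffer, views out
# this rank's slice of it, and copies every parameter's all-gather input into
# that slice with a single foreach copy.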
@torch.library.impl(lib, "all_gather_copy_in", "Meta")
def all_gather_copy_in_meta(
all_gather_inputs: list[torch.Tensor],
inp_split_sizes: list[int],
all_gather_input_numel: int,
world_size: int,
rank: int,
dtype: torch.dtype,
device: torch.device,
group_name: str,
allocate_memory_from_process_group: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
all_gather_output = torch.empty(
(all_gather_input_numel * world_size,), dtype=dtype, device="meta"
)
all_gather_input = all_gather_output.narrow(
0, all_gather_input_numel * rank, all_gather_input_numel
)
return all_gather_input, all_gather_output
@torch.library.impl(lib, "all_gather_copy_in", "CUDA")
@torch.library.impl(lib, "all_gather_copy_in", "XPU")
@torch.library.impl(lib, "all_gather_copy_in", "HPU")
@torch.library.impl(lib, "all_gather_copy_in", "CPU")
@torch.library.impl(lib, "all_gather_copy_in", "MTIA")
@torch.library.impl(lib, "all_gather_copy_in", "PrivateUse1")
def all_gather_copy_in_cuda(
all_gather_inputs: list[torch.Tensor],
inp_split_sizes: list[int],
all_gather_input_numel: int,
world_size: int,
rank: int,
dtype: torch.dtype,
device: torch.device,
group_name: str,
allocate_memory_from_process_group: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
all_gather_output = allocate_memory(
all_gather_input_numel * world_size,
dtype=dtype,
device=device,
group=_resolve_process_group(group_name),
from_process_group=allocate_memory_from_process_group,
)
all_gather_input = all_gather_output.narrow(
0, all_gather_input_numel * rank, all_gather_input_numel
)
foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
with torch.no_grad():
torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs)
return all_gather_input, all_gather_output
lib.define(
"split_with_sizes_copy(Tensor all_gather_output, SymInt[] all_gather_input_split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()"
)
@torch.library.impl(lib, "split_with_sizes_copy", "Meta")
@torch.library.impl(lib, "split_with_sizes_copy", "CUDA")
@torch.library.impl(lib, "split_with_sizes_copy", "XPU")
@torch.library.impl(lib, "split_with_sizes_copy", "HPU")
@torch.library.impl(lib, "split_with_sizes_copy", "CPU")
@torch.library.impl(lib, "split_with_sizes_copy", "MTIA")
@torch.library.impl(lib, "split_with_sizes_copy", "PrivateUse1")
def split_with_sizes_copy(
all_gather_output: torch.Tensor,
all_gather_input_split_sizes: list[int],
dim: int,
out: list[torch.Tensor],
) -> None:
torch.split_with_sizes_copy(
all_gather_output, all_gather_input_split_sizes, dim=dim, out=out
)
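# `chunk_cat` wraps `torch._chunk_cat`, which FSDP uses to build the
# reduce-scatter input: each gradient is chunked into `num_chunks` pieces
# along `dim` (zero-padding as needed) and the i-th chunks of all gradients
# are laid out contiguously in row i of `out`.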
lib.define(
"chunk_cat(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> ()"
)
@torch.library.impl(lib, "chunk_cat", "Meta")
@torch.library.impl(lib, "chunk_cat", "CUDA")
@torch.library.impl(lib, "chunk_cat", "XPU")
@torch.library.impl(lib, "chunk_cat", "HPU")
@torch.library.impl(lib, "chunk_cat", "CPU")
@torch.library.impl(lib, "chunk_cat", "MTIA")
@torch.library.impl(lib, "chunk_cat", "PrivateUse1")
def chunk_cat(
tensors: list[torch.Tensor],
dim: int,
num_chunks: int,
out: torch.Tensor,
) -> None:
torch._chunk_cat(tensors, dim, num_chunks, out=out)
@torch.no_grad()
def foreach_all_gather(
fsdp_params: list[FSDPParam],
group: dist.ProcessGroup,
async_op: bool,
all_gather_copy_in_stream: torch.Stream,
all_gather_stream: torch.Stream,
device: torch.device,
allocate_memory_from_process_group: bool = False,
) -> Optional[AllGatherResult]:
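    """
    Copies each parameter's all-gather inputs into one flat buffer on
    ``all_gather_copy_in_stream``, then issues a single
    ``all_gather_into_tensor`` on ``all_gather_stream``, returning the output
    buffer plus the metadata needed to copy the results back out per parameter.
    """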
world_size, rank = group.size(), group.rank()
device_handle = _get_device_handle(device.type)
with device_handle.stream(all_gather_copy_in_stream):
param_all_gather_inputs = _get_param_all_gather_inputs(fsdp_params)
(
param_all_gather_input_dtypes,
param_all_gather_input_numels,
dtype,
) = _get_all_gather_input_metadatas(param_all_gather_inputs)
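        # `dtype` is uint8 when the inputs have mixed dtypes, in which case
        # the flat all-gather buffer is treated as raw bytes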
if dtype == torch.uint8:
all_gather_inputs = [
t.view(torch.uint8) for ts in param_all_gather_inputs for t in ts
]
else:
all_gather_inputs = [*chain.from_iterable(param_all_gather_inputs)]
inp_split_sizes = [t.numel() for t in all_gather_inputs]
all_gather_input_numel = sum(inp_split_sizes)
all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(
all_gather_inputs,
inp_split_sizes,
all_gather_input_numel,
world_size,
rank,
dtype,
device,
group.group_name,
allocate_memory_from_process_group,
)
del param_all_gather_inputs
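    # Order the all-gather after the copy-in without blocking the CPU thread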
all_gather_stream.wait_stream(all_gather_copy_in_stream)
with device_handle.stream(all_gather_stream):
all_gather_work = dist.all_gather_into_tensor(
output_tensor=all_gather_output,
input_tensor=all_gather_input,
group=group,
async_op=async_op,
)
all_gather_event = all_gather_stream.record_event()
return AllGatherResult(
all_gather_output,
all_gather_event,
all_gather_work,
param_all_gather_input_dtypes,
param_all_gather_input_numels,
inp_split_sizes,
)
@torch.no_grad()
def _get_param_all_gather_inputs(
fsdp_params: list[FSDPParam],
) -> list[list[torch.Tensor]]:
if compiled_autograd_enabled():
return [fsdp_param.all_gather_inputs for fsdp_param in fsdp_params]
    # Intentionally try a fast path that bypasses abstractions for the common
    # FSDP case of bf16/fp32 mixed precision, using a foreach copy for lower
    # CPU overhead and more efficient copying in eager mode
def use_foreach_copy(fsdp_param: FSDPParam) -> bool:
return (
fsdp_param.param_dtype is not None
and not fsdp_param.offload_to_cpu
and not hasattr(fsdp_param._sharded_local_tensor, "fsdp_pre_all_gather")
)
param_all_gather_inputs: list[list[torch.Tensor]] = [[] for _ in fsdp_params]
foreach_copy_indices: list[int] = []
foreach_copy_inputs: list[torch.Tensor] = []
foreach_copy_input_numels: list[int] = []
# 1st pass: for foreach-copy parameters, get inputs and metadata for the
# foreach copy, and for the others, actually get their all-gather inputs
for i, fsdp_param in enumerate(fsdp_params):
if use_foreach_copy(fsdp_param):
foreach_copy_indices.append(i)
all_gather_input = (
fsdp_param._sharded_param_data
if fsdp_param.sharded_state == ShardedState.SHARDED
else cast(torch.Tensor, fsdp_param._sharded_post_forward_param_data)
)
foreach_copy_inputs.append(all_gather_input)
foreach_copy_input_numels.append(all_gather_input.numel())
else:
param_all_gather_inputs[i] = fsdp_param.all_gather_inputs
# 2nd pass: use foreach copy to compute the remaining all-gather inputs
if foreach_copy_inputs:
fsdp_param_0 = fsdp_params[foreach_copy_indices[0]]
param_dtype, device = fsdp_param_0.param_dtype, fsdp_param_0.device
flat_foreach_copy_input = torch.empty(
(sum(foreach_copy_input_numels),), device=device, dtype=param_dtype
)
splits = torch.split(flat_foreach_copy_input, foreach_copy_input_numels)
torch._foreach_copy_(splits, foreach_copy_inputs)
for i, split in zip(foreach_copy_indices, splits):
param_all_gather_inputs[i] = [split]
return param_all_gather_inputs
@torch.no_grad()
def foreach_all_gather_copy_out(
all_gather_result: AllGatherResult,
fsdp_params: list[FSDPParam],
group: dist.ProcessGroup,
) -> None:
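    """
    Copies the flat all-gather output back into each parameter's all-gather
    output tensors with one split-with-sizes copy along the rank dimension,
    plus a chunk-cat pass for parameters sharded on a nonzero dim.
    """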
(
all_gather_output,
all_gather_event,
all_gather_work,
param_all_gather_input_dtypes,
param_all_gather_input_numels,
all_gather_input_split_sizes,
) = all_gather_result
_dtype, device = all_gather_output.dtype, all_gather_output.device
device_handle = _get_device_handle(device.type)
if all_gather_event is not None: # sync op
device_handle.current_stream().wait_event(all_gather_event)
if isinstance(all_gather_work, dist.distributed_c10d.Work): # async op
all_gather_work.wait()
world_size, device = group.size(), all_gather_output.device
split_with_sizes_out: list[torch.Tensor] = []
shard_i_copy_infos: list[tuple[FSDPParam, list[torch.Tensor]]] = []
for all_gather_input_numels, all_gather_input_dtypes, fsdp_param in zip(
param_all_gather_input_numels, param_all_gather_input_dtypes, fsdp_params
):
# NOTE: Under compile, make sure we always recreate all_gather_outputs
# per AllGather. See [Note: Invariants for torch.compile Traceable FSDP2].
force_recreate = compiled_autograd_enabled()
fsdp_param.init_all_gather_outputs(
all_gather_input_numels,
all_gather_input_dtypes,
world_size,
device,
force_recreate=force_recreate,
)
if not force_recreate:
fsdp_param.alloc_all_gather_outputs()
param_all_gather_outputs = fsdp_param.all_gather_outputs
if fsdp_param.fsdp_placement.dim != 0:
# Copy to a temporary and then chunk-cat into the final all-gather
# output tensors
param_all_gather_outputs = [
torch.empty_like(t) for t in param_all_gather_outputs
]
shard_i_copy_infos.append((fsdp_param, param_all_gather_outputs))
split_with_sizes_out.extend(param_all_gather_outputs)
all_gather_output = all_gather_output.view(world_size, -1)
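    # Viewed as (world_size, per_rank_numel): splitting along dim=1 by the
    # per-input split sizes and copying out writes each rank's contribution
    # for every parameter into the matching all-gather output tensor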
if all_gather_output.dtype == torch.uint8:
out = [t.view(world_size, -1).view(torch.uint8) for t in split_with_sizes_out]
else:
out = [t.view(world_size, -1) for t in split_with_sizes_out]
    # Only avoid the version-counter (VC) bump for outputs that are not in
    # inference mode
    if torch._dynamo.is_compiling():
        # Under torch.compile, inference_mode is turned off for fake-tensor
        # propagation, so checking `is_inference()` would graph break; compile
        # does not track VCs anyway, so skip the optimization entirely.
        non_inference_outs = []
else:
non_inference_outs = [o for o in out if not o.is_inference()]
if len(non_inference_outs) > 0:
with torch.autograd._unsafe_preserve_version_counter(tuple(non_inference_outs)):
torch.ops.fsdp.split_with_sizes_copy(
all_gather_output, all_gather_input_split_sizes, dim=1, out=out
)
else:
torch.ops.fsdp.split_with_sizes_copy(
all_gather_output, all_gather_input_split_sizes, dim=1, out=out
)
for fsdp_param, param_all_gather_outputs in shard_i_copy_infos:
# Chunk-cat from the temporary to the final all-gather output tensors
shard_dim = fsdp_param.fsdp_placement.dim
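        # The split-with-sizes copy above left each rank's shard stacked along
        # dim 0 of the temporary; chunk along dim 0 and re-concatenate along
        # the true shard dim to recover the unsharded layout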
with torch.autograd._unsafe_preserve_version_counter(
tuple(fsdp_param.all_gather_outputs)
):
for param_all_gather_output, target_all_gather_output in zip(
param_all_gather_outputs, fsdp_param.all_gather_outputs
):
padded_sharded_size = (
fsdp_param.padded_sharded_param_size
if fsdp_param.sharded_state == ShardedState.SHARDED
else cast(
torch.Tensor, fsdp_param._sharded_post_forward_param_data
).size()
)
pre_param_size = list(padded_sharded_size)
pre_param_size[0] *= world_size
chunks = torch.chunk(
param_all_gather_output.view(pre_param_size), world_size, dim=0
)
post_param_size = list(padded_sharded_size)
post_param_size[shard_dim] *= world_size
cat_out = target_all_gather_output.view(post_param_size)
torch.cat(chunks, dim=shard_dim, out=cat_out)
@torch.no_grad()
def foreach_reduce(
fsdp_params: list[FSDPParam],
unsharded_grads: list[torch.Tensor],
reduce_scatter_group: dist.ProcessGroup,
reduce_scatter_stream: torch.Stream,
orig_dtype: Optional[torch.dtype],
reduce_dtype: Optional[torch.dtype],
device: torch.device,
gradient_divide_factor: Optional[float],
all_reduce_group: Optional[dist.ProcessGroup], # not `None` iff HSDP
all_reduce_stream: torch.Stream,
all_reduce_grads: bool,
partial_reduce_output: Optional[torch.Tensor], # only used for HSDP
all_reduce_hook: Optional[Callable[[torch.Tensor], None]],
allocate_memory_from_process_group: bool = False,
force_sum_reduction_for_comms: bool = False,
) -> tuple[
torch.Tensor,
torch.Event,
torch.Event,
Optional[torch.Tensor],
Optional[torch.Event],
Optional[torch.Tensor],
]:
"""
``unsharded_grads`` owns the references to the gradients computed by
autograd, so clearing the list frees the gradients.
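
    Returns ``(reduce_scatter_input, reduce_scatter_event, post_reduce_event,
    all_reduce_input, all_reduce_event, partial_reduce_output)``:
    ``all_reduce_input``/``all_reduce_event`` are set only when the HSDP
    all-reduce actually runs, and ``partial_reduce_output`` is non-``None``
    only when the HSDP all-reduce is deferred.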
"""
grad_dtypes = {grad.dtype for grad in unsharded_grads}
if len(grad_dtypes) != 1:
# Check this at runtime since it could be a real runtime error if e.g.
# fp8 weights do not produce the correct higher precision gradients
_raise_assert_with_print(
f"FSDP reduce-scatter expects uniform gradient dtype but got {grad_dtypes}"
)
grad_dtype = unsharded_grads[0].dtype
reduce_dtype = reduce_dtype or grad_dtype
(predivide_factor, postdivide_factor, reduce_scatter_op, all_reduce_op) = (
_get_gradient_divide_factors(
reduce_scatter_group,
all_reduce_group,
reduce_dtype,
device.type,
gradient_divide_factor,
force_sum_reduction_for_comms,
)
)
world_size = reduce_scatter_group.size()
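    # For Shard(i) with i > 0, re-lay out the gradient so that each rank's
    # shard is contiguous along dim 0, matching the flat reduce-scatter layout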
for i, (fsdp_param, unsharded_grad) in enumerate(zip(fsdp_params, unsharded_grads)):
if (shard_dim := fsdp_param.fsdp_placement.dim) == 0:
continue
assert unsharded_grad.size(shard_dim) % world_size == 0, (
f"Shard({shard_dim}) requires even sharding: {unsharded_grad.size()=} {world_size=}"
)
chunks = torch.chunk(unsharded_grad, world_size, dim=shard_dim)
unsharded_grads[i] = torch.cat(chunks, dim=0)
padded_unsharded_sizes = tuple(
_get_dim0_padded_size(grad.size(), world_size) for grad in unsharded_grads
)
reduce_scatter_input_numel = sum(s.numel() for s in padded_unsharded_sizes)
reduce_scatter_output_numel = reduce_scatter_input_numel // world_size
reduce_scatter_input = allocate_memory(
reduce_scatter_input_numel,
dtype=reduce_dtype,
device=device,
group=reduce_scatter_group,
from_process_group=allocate_memory_from_process_group,
)
device_handle = _get_device_handle(device.type)
foreach_reduce_scatter_copy_in(unsharded_grads, reduce_scatter_input, world_size)
current_stream = device_handle.current_stream()
# Only after the copy-in finishes can we free the gradients
unsharded_grads.clear()
reduce_scatter_stream.wait_stream(current_stream)
all_reduce_input = None
all_reduce_event = None
with device_handle.stream(reduce_scatter_stream):
reduce_output = allocate_memory(
reduce_scatter_output_numel,
dtype=reduce_dtype,
device=device,
group=reduce_scatter_group,
from_process_group=allocate_memory_from_process_group,
)
_div_if_needed(reduce_scatter_input, predivide_factor)
dist.reduce_scatter_tensor(
output=reduce_output,
input=reduce_scatter_input,
group=reduce_scatter_group,
op=reduce_scatter_op,
)
reduce_scatter_event = reduce_scatter_stream.record_event()
post_reduce_stream = reduce_scatter_stream
if all_reduce_group is not None: # HSDP
# Accumulations must run in the reduce-scatter stream
if not all_reduce_grads:
if partial_reduce_output is not None:
partial_reduce_output += reduce_output
else:
partial_reduce_output = reduce_output
return (
reduce_scatter_input,
reduce_scatter_event,
post_reduce_stream.record_event(),
all_reduce_input,
all_reduce_event,
partial_reduce_output,
)
if partial_reduce_output is not None:
reduce_output += partial_reduce_output
post_reduce_stream = all_reduce_stream
all_reduce_stream.wait_stream(reduce_scatter_stream)
with device_handle.stream(all_reduce_stream):
dist.all_reduce(
reduce_output,
group=all_reduce_group,
op=all_reduce_op,
)
all_reduce_input = reduce_output
all_reduce_event = all_reduce_stream.record_event()
# -- END: ops in reduce_scatter stream
if all_reduce_hook is not None:
# Execute user-specified all reduce hook.
# If native HSDP is used, this is executed after the HSDP all reduce.
# If 1-d FSDP is used, this is executed post reduce-scatter.
post_reduce_stream = all_reduce_stream
all_reduce_stream.wait_stream(reduce_scatter_stream)
with device_handle.stream(all_reduce_stream):
all_reduce_hook(reduce_output)
# -- END: ops post reduce_scatter
with device_handle.stream(post_reduce_stream):
_div_if_needed(reduce_output, postdivide_factor)
reduce_output = _to_dtype_if_needed(reduce_output, orig_dtype)
# View out and accumulate sharded gradients
flat_grad_offset = 0 # [0, reduce_scatter_output_numel - 1]
for padded_unsharded_size, fsdp_param in zip(
padded_unsharded_sizes, fsdp_params
):
# Assume even sharding for Shard(i), i > 0; otherwise would require
# copy-out for contiguous strides
new_sharded_grad = torch.as_strided(
reduce_output,
size=fsdp_param.sharded_size,
stride=fsdp_param.contiguous_sharded_stride,
storage_offset=flat_grad_offset,
)
to_accumulate_grad = fsdp_param.sharded_param.grad is not None
if fsdp_param.offload_to_cpu:
# Only overlap the D2H copy (copying to pinned memory) if not
# accumulating gradients since the CPU add kernel depends on
# the copy result and we cannot run the add as a callback
non_blocking = fsdp_param.pin_memory and not to_accumulate_grad
# Since the GPU sharded gradient is allocated in the RS stream,
# we can free it here by not keeping a ref without waiting for
# the D2H copy since future RS-stream ops run after the copy
new_sharded_grad = new_sharded_grad.to(
torch.device("cpu"), non_blocking=non_blocking
)
if non_blocking:
# Record an event on which to block the CPU thread to
# ensure that the D2H copy finishes before the optimizer
fsdp_param.grad_offload_event = reduce_scatter_stream.record_event()
if to_accumulate_grad:
assert isinstance(fsdp_param.sharded_param.grad, DTensor)
fsdp_param.sharded_param.grad._local_tensor += new_sharded_grad
else:
new_sharded_dtensor_grad = fsdp_param.to_sharded_dtensor(
new_sharded_grad
)
fsdp_param.sharded_param.grad = new_sharded_dtensor_grad
if not compiled_autograd_enabled():
for hook in (
getattr(fsdp_param.sharded_param, "_post_accumulate_grad_hooks", {})
or {}
).values():
hook(fsdp_param.sharded_param)
padded_sharded_numel = padded_unsharded_size.numel() // world_size
flat_grad_offset += padded_sharded_numel
post_reduce_event = post_reduce_stream.record_event()
# The RS output is allocated in the RS stream and used in the default
# stream (for optimizer). To ensure its memory is not reused for later
# RSs, we do not need extra synchronization since the sharded parameters
# hold refs through the end of backward.
return (
reduce_scatter_input,
reduce_scatter_event,
post_reduce_event,
all_reduce_input,
all_reduce_event,
None,
)
def foreach_reduce_scatter_copy_in(
unsharded_grads: list[torch.Tensor],
reduce_scatter_input: torch.Tensor,
world_size: int,
) -> None:
reduce_scatter_input = reduce_scatter_input.view(world_size, -1)
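    # With the input viewed as (world_size, -1), chunk_cat writes chunk j of
    # every gradient into row j, which is the layout reduce_scatter_tensor
    # expects (row j reduces down to rank j)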
torch.ops.fsdp.chunk_cat(
unsharded_grads, dim=0, num_chunks=world_size, out=reduce_scatter_input
)
def _get_all_gather_input_metadatas(
param_all_gather_inputs: list[list[torch.Tensor]],
) -> tuple[list[list[torch.dtype]], list[list[int]], torch.dtype]:
param_all_gather_input_dtypes: list[list[torch.dtype]] = []
param_all_gather_input_numels: list[list[int]] = []
all_gather_dtype = param_all_gather_inputs[0][0].dtype
for all_gather_inputs in param_all_gather_inputs:
input_dtypes: list[torch.dtype] = []
input_numels: list[int] = []
for all_gather_input in all_gather_inputs:
if all_gather_input.dtype != all_gather_dtype:
all_gather_dtype = torch.uint8
input_dtypes.append(all_gather_input.dtype)
input_numels.append(all_gather_input.numel())
param_all_gather_input_dtypes.append(input_dtypes)
param_all_gather_input_numels.append(input_numels)
return (
param_all_gather_input_dtypes,
param_all_gather_input_numels,
all_gather_dtype,
)
def _get_gradient_divide_factors(
reduce_scatter_group: dist.ProcessGroup,
all_reduce_group: Optional[dist.ProcessGroup],
reduce_dtype: torch.dtype,
device_type: str = "",
factor: Optional[float] = None,
force_sum_reduction_for_comms: bool = False,
) -> tuple[
Optional[float],
Optional[float],
Union[dist.ReduceOp, dist.ReduceOp.RedOpType],
Union[dist.ReduceOp, dist.ReduceOp.RedOpType],
]:
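    """
    Returns ``(pre_divide, post_divide, reduce_scatter_op, all_reduce_op)``.
    When the reduce dtype has no overflow risk and sum-only comms are not
    forced, the division is folded into the collectives (AVG or premul-sum);
    otherwise the gradients are divided explicitly before/after the reduction.
    """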
# MTIA appears to only support SUM reduction, hence we force it implicitly
if device_type == "mtia":
force_sum_reduction_for_comms = True
# For fp32/bf16, we do not need to worry about overflow/underflow, so we
# use NCCL's built-in division to avoid separate div kernels
overflow_risk = reduce_dtype not in (torch.float32, torch.bfloat16)
data_parallel_size = reduce_scatter_group.size()
if all_reduce_group is not None:
data_parallel_size *= all_reduce_group.size()
if factor is None:
factor = float(data_parallel_size)
if not overflow_risk and not force_sum_reduction_for_comms:
if factor == data_parallel_size:
# Warning: NCCL ReduceOp.AVG may produce incorrect results with
# world size 1.
return None, None, ReduceOp.AVG, ReduceOp.AVG
else:
reduce_scatter_op = torch.distributed._make_nccl_premul_sum(1 / factor)
return None, None, reduce_scatter_op, ReduceOp.SUM
pre_factor: Optional[float]
if overflow_risk:
# Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid
# overflow/underflow. For N data parallel workers, each worker computes
# g_i, and they collectively reduce (g_1 + ... + g_N) / N. To avoid
# overflow/underflow, we divide by ~sqrt(N) before/after the reduction.
pre_factor = 1
while factor % pre_factor == 0 and factor / pre_factor > pre_factor:
pre_factor *= 2
post_factor = factor / pre_factor
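        # e.g. factor=64 gives pre_factor=8 and post_factor=8, splitting the
        # division roughly evenly across the two sides of the reduction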
else:
# Prefer post-multiplying as it operates on less data and is thus faster
pre_factor, post_factor = None, factor
return pre_factor, post_factor, ReduceOp.SUM, ReduceOp.SUM
def _div_if_needed(tensor: torch.Tensor, div_factor: Optional[float]) -> None:
if div_factor is not None and div_factor != 1:
tensor.div_(div_factor)