# mypy: allow-untyped-defs r"""Functional interface.""" import math from typing import List from torch import Tensor from .adadelta import adadelta # type: ignore[attr-defined] # noqa: F401 from .adagrad import _make_sparse, adagrad # type: ignore[attr-defined] # noqa: F401 from .adam import adam # type: ignore[attr-defined] # noqa: F401 from .adamax import adamax # type: ignore[attr-defined] # noqa: F401 from .adamw import adamw # type: ignore[attr-defined] # noqa: F401 from .asgd import asgd # type: ignore[attr-defined] # noqa: F401 from .nadam import nadam # type: ignore[attr-defined] # noqa: F401 from .radam import radam # type: ignore[attr-defined] # noqa: F401 from .rmsprop import rmsprop # type: ignore[attr-defined] # noqa: F401 from .rprop import rprop # type: ignore[attr-defined] # noqa: F401 from .sgd import sgd # type: ignore[attr-defined] # noqa: F401 # TODO: use foreach API in optim._functional to do all the computation def sparse_adam( params: List[Tensor], grads: List[Tensor], exp_avgs: List[Tensor], exp_avg_sqs: List[Tensor], state_steps: List[int], *, eps: float, beta1: float, beta2: float, lr: float, maximize: bool, ): r"""Functional API that performs Sparse Adam algorithm computation. See :class:`~torch.optim.SparseAdam` for details. """ for i, param in enumerate(params): grad = grads[i] grad = grad if not maximize else -grad grad = grad.coalesce() # the update is non-linear so indices must be unique grad_indices = grad._indices() grad_values = grad._values() if grad_values.numel() == 0: # Skip update for empty grad continue size = grad.size() exp_avg = exp_avgs[i] exp_avg_sq = exp_avg_sqs[i] step = state_steps[i] def make_sparse(values): constructor = grad.new if grad_indices.dim() == 0 or values.dim() == 0: return constructor().resize_as_(grad) return constructor(grad_indices, values, size) # Decay the first and second moment running average coefficient # old <- b * old + (1 - b) * new # <==> old += (1 - b) * (new - old) old_exp_avg_values = exp_avg.sparse_mask(grad)._values() exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1) exp_avg.add_(make_sparse(exp_avg_update_values)) old_exp_avg_sq_values = exp_avg_sq.sparse_mask(grad)._values() exp_avg_sq_update_values = ( grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2) ) exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values)) # Dense addition again is intended, avoiding another sparse_mask numer = exp_avg_update_values.add_(old_exp_avg_values) exp_avg_sq_update_values.add_(old_exp_avg_sq_values) denom = exp_avg_sq_update_values.sqrt_().add_(eps) del exp_avg_update_values, exp_avg_sq_update_values bias_correction1 = 1 - beta1**step bias_correction2 = 1 - beta2**step step_size = lr * math.sqrt(bias_correction2) / bias_correction1 param.add_(make_sparse(-step_size * numer.div_(denom)))