Source code for syft.lib.torch.parameter

# stdlib

# third party
import torch as th
from torch.nn import Parameter

# syft relative
from ...generate_wrapper import GenerateWrapper
from ...lib.torch.tensor_util import tensor_deserializer
from ...lib.torch.tensor_util import tensor_serializer
from ...proto.lib.torch.parameter_pb2 import ParameterProto as Parameter_PB

torch_tensor = th.tensor([1.0, 2.0, 3.0])
torch_parameter_type = type(th.nn.parameter.Parameter(torch_tensor))


[docs]def object2proto(obj: object) -> Parameter_PB: proto = Parameter_PB() tensor_data = getattr(obj, "data", None) if tensor_data is not None: proto.tensor = tensor_serializer(tensor_data) proto.requires_grad = getattr(obj, "requires_grad", False) grad = getattr(obj, "grad", None) if grad is not None: proto.grad = tensor_serializer(grad) # opacus monkey patches this onto the Parameter class grad_sample = getattr(obj, "grad_sample", None) if grad_sample is not None: proto.grad_sample = tensor_serializer(grad_sample) return proto
[docs]def proto2object(proto: Parameter_PB) -> Parameter: data = tensor_deserializer(proto.tensor) param = Parameter(data, requires_grad=proto.requires_grad) if proto.HasField("grad"): param.grad = tensor_deserializer(proto.grad) # opacus monkey patches this onto the Parameter class if proto.HasField("grad_sample"): param.grad_sample = tensor_deserializer(proto.grad_sample) return param
GenerateWrapper( wrapped_type=torch_parameter_type, import_path="torch.nn.parameter.Parameter", protobuf_scheme=Parameter_PB, type_object2proto=object2proto, type_proto2object=proto2object, )