Source code for syft.lib.torch.device

# stdlib
from typing import Any
from typing import Optional

# third party
from torch import device

# syft relative
from ...generate_wrapper import GenerateWrapper
from ...proto.lib.torch.device_pb2 import Device as Device_PB

# Sentinel stored in the protobuf ``index`` field when a torch.device has no
# index (``device.index is None``); an integer field cannot hold None itself.
# Presumably -2 was picked so it cannot collide with a real device ordinal.
INDEX_NONE: int = -2


def object2proto(obj: device) -> "Device_PB":
    """Serialize a ``torch.device`` into a ``Device_PB`` protobuf message.

    Args:
        obj: the device to serialize.

    Returns:
        A fresh ``Device_PB`` whose ``type`` copies ``obj.type`` and whose
        ``index`` copies ``obj.index``, or ``INDEX_NONE`` when the device
        carries no index.
    """
    message = Device_PB()
    message.type = obj.type
    ordinal = obj.index
    if ordinal is None:
        # None cannot be written to an int field, so store the sentinel.
        message.index = INDEX_NONE
    else:
        message.index = ordinal
    return message


def proto2object(proto: "Device_PB") -> Any:
    """Rebuild a ``torch.device`` from a ``Device_PB`` message.

    Args:
        proto: a message shaped like the output of ``object2proto``.

    Returns:
        ``torch.device(proto.type, index)``, where ``index`` is ``None`` if the
        stored value equals ``INDEX_NONE`` and the stored value otherwise.
    """
    kind = proto.type
    stored = proto.index
    index: Optional[int] = None
    if stored != INDEX_NONE:
        # Any value other than the sentinel is a real device ordinal.
        index = stored
    return device(kind, index)


# Pass torch.device, its import path, its protobuf message class and the two
# conversion functions to GenerateWrapper. The return value is discarded, so
# the call runs for its side effects (presumably wrapper generation /
# registration -- confirm in generate_wrapper).
GenerateWrapper(
    wrapped_type=device,
    import_path="torch.device",
    type_object2proto=object2proto,
    type_proto2object=proto2object,
    protobuf_scheme=Device_PB,
)