# stdlib
import re
from typing import Any
from typing import Dict
from typing import List
# third party
from packaging import version
import torch
# syft relative
from ...generate_wrapper import GenerateWrapper
from ...lib.util import full_name_with_name
from ...proto.lib.torch.returntypes_pb2 import ReturnTypes as ReturnTypes_PB
from ..torch.tensor_util import tensor_deserializer
from ..torch.tensor_util import tensor_serializer
# TODO: a better way. Look at https://github.com/OpenMined/PySyft/issues/5249
# ``parent`` is the module object that wrap_type() attaches each supported
# return type to, so that ``torch.return_types.<name>`` can be resolved.
module_type = type(torch)
parent = getattr(torch, "return_types", None)
if not isinstance(parent, module_type):
    # torch does not expose a ``return_types`` module here: install an empty
    # placeholder module. When a real module already exists it is reused rather
    # than replaced, so return types syft does not register stay reachable.
    parent = module_type(name="return_types")
    torch.__dict__["return_types"] = parent
def get_field_names(obj: Any) -> List[str]:
    """Return the field names of a torch return type, parsed from its string form.

    The parser assumes the layout torch uses for return types: a header line
    followed by one ``<field>=<value>`` entry per line, e.g.::

        torch.return_types.max(
        values=tensor(2.),
        indices=tensor(1))

    Only the identifier directly after a newline and before the first ``=`` is
    captured. This prevents ``=`` signs inside a value's repr (``dtype=...``,
    ``grad_fn=...``) from leaking into the name, which the previous greedy
    ``(.*)=`` pattern allowed.

    Args:
        obj: the return-type instance (any object whose ``str`` has that layout).

    Returns:
        The field names in the order they appear in ``str(obj)``; empty if none.
    """
    return re.findall(r"\n(\w+)=", str(obj))
def get_supported_types_fields() -> Dict[type, List]:
    """Discover the torch return types to wrap, together with their field names.

    One sample op per return type is run on small constant tensors. The type of
    each result becomes a key, and its field names (from ``get_field_names``)
    become the value. ``cummax``/``cummin`` are only sampled when the torch
    release is >= 1.5.0.

    Returns:
        A dict mapping each sampled return type to its list of field names,
        in sampling order.
    """
    x = torch.Tensor([[1, 2], [1, 2]])
    s = torch.tensor(
        [[-0.1000, 0.1000, 0.2000], [0.2000, 0.3000, 0.4000], [0.0000, -0.3000, 0.5000]]
    )

    samples = []
    # Drop any "+<local>" build suffix before comparing release numbers.
    release = torch.__version__.split("+")[0]
    if version.parse(release) >= version.parse("1.5.0"):
        samples.extend([x.cummax(0), x.cummin(0)])
    samples.extend(
        [
            x.kthvalue(1),
            x.slogdet(),
            x.mode(),
            s.sort(),
            s.topk(1),
            s.svd(),
            s.geqrf(),
            s.median(0),
            s.max(0),
            s.min(0),
        ]
    )
    # eig, lstsq, qr, solve, symeig and triangular_solve are deliberately not
    # sampled (they were disabled), so their return types are not registered.

    return {type(sample): get_field_names(sample) for sample in samples}
def wrap_type(typ: type, fields: List[str]) -> None:
    """Register protobuf serialization for one torch return type.

    Builds an ``object2proto``/``proto2object`` pair for ``typ`` and hands them
    to ``GenerateWrapper`` under the import path ``<typ.__module__>.<typ.__name__>``.
    It then attaches ``typ`` to the module-level ``parent`` module.

    Args:
        typ: the torch return type to wrap.
        fields: attribute names of ``typ``; their values are stored in
            ``ReturnTypes_PB.values`` in this order.
    """

    def object2proto(obj: object) -> ReturnTypes_PB:
        # Each field value goes through tensor_serializer; a field the object
        # lacks is handed to the serializer as None.
        proto = ReturnTypes_PB()
        proto.obj_type = full_name_with_name(klass=obj._sy_serializable_wrapper_type)  # type: ignore
        field_values = [getattr(obj, name, None) for name in fields]
        proto.values.extend([tensor_serializer(value) for value in field_values])
        return proto

    def proto2object(proto: ReturnTypes_PB) -> "typ":  # type: ignore
        # Values are positional: rebuild typ from the deserialized sequence.
        return typ([tensor_deserializer(item) for item in proto.values])

    GenerateWrapper(
        wrapped_type=typ,
        import_path=f"{typ.__module__}.{typ.__name__}",
        protobuf_scheme=ReturnTypes_PB,
        type_object2proto=object2proto,
        type_proto2object=proto2object,
    )

    # TODO: a better way. Look at https://github.com/OpenMined/PySyft/issues/5249
    # Add the type to torch.return_types (the ``parent`` module).
    parent.__dict__[typ.__name__] = typ
# Import-time registration: discover every supported torch return type and its
# field names, then register a serializer for each one via wrap_type().
types_fields = get_supported_types_fields()
for typ, fields in types_fields.items():
    wrap_type(typ, fields)