# stdlib
from collections import OrderedDict as PyOrderedDict
from collections.abc import ItemsView
from collections.abc import KeysView
from collections.abc import ValuesView
from typing import Any
from typing import Optional
# third party
from google.protobuf.reflection import GeneratedProtocolMessageType
# syft relative
from .... import deserialize
from .... import serialize
from ....core.common.serde.serializable import bind_protobuf
from ....core.common.uid import UID
from ....logger import traceback_and_raise
from ....proto.lib.python.collections.ordered_dict_pb2 import (
OrderedDict as OrderedDict_PB,
)
from ..iterator import Iterator
from ..primitive_factory import PrimitiveFactory
from ..primitive_factory import isprimitive
from ..primitive_interface import PyPrimitive
from ..types import SyPrimitiveRet
from ..util import downcast
from ..util import upcast
@bind_protobuf
class OrderedDict(PyOrderedDict, PyPrimitive):
    """Syft wrapper around :class:`collections.OrderedDict`.

    Behaves like the stdlib ``OrderedDict`` but carries a Syft ``UID`` and
    passes most results through ``PrimitiveFactory.generate_primitive`` so
    that callers receive Syft primitives. Instances can be serialized to and
    from the ``OrderedDict_PB`` protobuf message.
    """

    def __init__(
        self, *args: Any, _id: Optional[UID] = None, **kwds: Any
    ) -> None:
        """Build the mapping exactly like ``collections.OrderedDict``.

        :param args: positional arguments forwarded to ``OrderedDict``.
        :param _id: UID for this object; a fresh ``UID()`` is generated for
            every instance when omitted.
        :param kwds: keyword items forwarded to ``OrderedDict``.
        """
        super().__init__(*args, **kwds)
        # A ``UID()`` default in the signature would be evaluated only once,
        # at definition time, and shared by every instance created without
        # an explicit ``_id``. Create one per instance instead.
        self._id = UID() if _id is None else _id

    @property
    def id(self) -> UID:
        """We reveal PyPrimitive.id as a property to discourage users and
        developers of Syft from modifying .id attributes after an object
        has been initialized.

        :return: returns the unique id of the object
        :rtype: UID
        """
        return self._id

    def __contains__(self, other: Any) -> SyPrimitiveRet:
        """Membership test; the result is wrapped as a Syft primitive."""
        res = super().__contains__(other)
        return PrimitiveFactory.generate_primitive(value=res)

    def __delitem__(self, other: Any) -> None:
        """Delete ``other``; the ``None`` result is wrapped as a Syft primitive."""
        res = super().__delitem__(other)
        return PrimitiveFactory.generate_primitive(value=res)

    def __eq__(self, other: Any) -> SyPrimitiveRet:
        """Order-sensitive equality (``OrderedDict`` semantics), wrapped."""
        res = super().__eq__(other)
        # Hand NotImplemented back untouched so Python can try the reflected
        # comparison instead of receiving a wrapped sentinel.
        if res is NotImplemented:
            return res
        return PrimitiveFactory.generate_primitive(value=res)

    def __getitem__(self, other: Any) -> SyPrimitiveRet:
        """Look up ``other``.

        Primitive values are wrapped as Syft primitives; anything else is
        returned as stored.
        """
        res = super().__getitem__(other)
        if isprimitive(value=res):
            return PrimitiveFactory.generate_primitive(value=res)
        # we can have torch.Tensor and other types
        return res

    def __iter__(self, max_len: Optional[int] = None) -> Iterator:
        """Iterate over keys through a Syft ``Iterator``."""
        return Iterator(super().__iter__(), max_len=max_len)

    def __len__(self) -> SyPrimitiveRet:
        """Number of items, wrapped as a Syft primitive."""
        res = super().__len__()
        return PrimitiveFactory.generate_primitive(value=res)

    def __ne__(self, other: Any) -> SyPrimitiveRet:
        """Negated equality, wrapped; NotImplemented is passed through."""
        res = super().__ne__(other)
        if res is NotImplemented:
            return res
        return PrimitiveFactory.generate_primitive(value=res)

    def __reversed__(self) -> Any:
        """Return the stdlib reverse key iterator, unwrapped."""
        # returns <class 'odict_iterator'>
        return super().__reversed__()

    def __setitem__(self, key: Any, value: Any) -> None:
        """Store ``value`` under ``key``; the ``None`` result is wrapped."""
        res = super().__setitem__(key, value)
        return PrimitiveFactory.generate_primitive(value=res)

    def clear(self) -> None:
        """Remove all items; the ``None`` result is wrapped."""
        res = super().clear()
        return PrimitiveFactory.generate_primitive(value=res)

    def copy(self) -> SyPrimitiveRet:
        """Shallow copy, passed through ``generate_primitive``."""
        res = super().copy()
        return PrimitiveFactory.generate_primitive(value=res)

    @classmethod
    def FromKeys(cls, iterable: Any, value: Any = None) -> SyPrimitiveRet:
        """Class-level ``fromkeys``: map every key in ``iterable`` to ``value``."""
        res = cls(PyOrderedDict.fromkeys(iterable, value))
        return PrimitiveFactory.generate_primitive(value=res)

    def fromkeys(  # type: ignore
        self, iterable: Any, value: Optional[object] = None
    ) -> SyPrimitiveRet:
        """Instance-level ``fromkeys``; builds a new mapping of this class."""
        res = super().fromkeys(iterable, value)
        return PrimitiveFactory.generate_primitive(value=res)

    def dict_get(self, other: Any) -> Any:
        """Equivalent of ``dict.get(other)`` (default ``None``).

        Primitive results are wrapped as Syft primitives; anything else is
        returned as stored.
        """
        res = super().get(other)
        if isprimitive(value=res):
            return PrimitiveFactory.generate_primitive(value=res)
        # we can have torch.Tensor and other types
        return res

    def items(self, max_len: Optional[int] = None) -> Iterator:  # type: ignore
        """Iterate ``(key, value)`` pairs through a Syft ``Iterator``."""
        return Iterator(ItemsView(self), max_len=max_len)

    def keys(self, max_len: Optional[int] = None) -> Iterator:  # type: ignore
        """Iterate keys through a Syft ``Iterator``."""
        return Iterator(KeysView(self), max_len=max_len)

    def move_to_end(self, other: Any, last: Any = True) -> Any:
        """Move key ``other`` to the end (or the start when ``last`` is false)."""
        res = super().move_to_end(other, last)
        return PrimitiveFactory.generate_primitive(value=res)

    def pop(self, *args: Any, **kwargs: Any) -> SyPrimitiveRet:
        """``OrderedDict.pop`` with the popped value wrapped."""
        res = super().pop(*args, **kwargs)
        return PrimitiveFactory.generate_primitive(value=res)

    def popitem(self, last: Any = True) -> SyPrimitiveRet:
        """Remove and return a ``(key, value)`` pair (LIFO when ``last``)."""
        res = super().popitem(last)
        return PrimitiveFactory.generate_primitive(value=res)

    def setdefault(
        self, key: Any, default: Optional[object] = None
    ) -> SyPrimitiveRet:
        """``OrderedDict.setdefault`` with the resulting value wrapped."""
        res = super().setdefault(key, default)
        return PrimitiveFactory.generate_primitive(value=res)

    def update(self, *args, **kwds: Any) -> SyPrimitiveRet:  # type: ignore
        """``OrderedDict.update``; the ``None`` result is wrapped."""
        res = super().update(*args, **kwds)
        return PrimitiveFactory.generate_primitive(value=res)

    def values(  # type: ignore
        self, *args: Any, max_len: Optional[int] = None
    ) -> Iterator:
        """Iterate values through a Syft ``Iterator``.

        :raises TypeError: if any positional argument is given, mirroring the
            builtin ``values()`` signature.
        """
        # this is what the super type does and there is a test in dict_test.py
        # test_values which checks for this so we could disable the test or
        # keep this workaround
        if len(args) > 0:
            traceback_and_raise(
                TypeError("values() takes 1 positional argument but 2 were given")
            )
        return Iterator(ValuesView(self), max_len=max_len)

    def _object2proto(self) -> OrderedDict_PB:
        """Serialize into an ``OrderedDict_PB`` message.

        Keys and values are stored as two parallel lists of serialized bytes,
        in the mapping's iteration order.
        """
        id_ = serialize(obj=self.id)
        # serialize to bytes so that we can avoid using StorableObject
        # otherwise we get recursion where the permissions of StorableObject
        # themselves utilise OrderedDict
        keys = [
            serialize(obj=downcast(value=element), to_bytes=True)
            for element in self.keys()
        ]
        # serialize to bytes so that we can avoid using StorableObject
        # otherwise we get recursion where the permissions of StorableObject
        # themselves utilise OrderedDict
        values = [
            serialize(obj=downcast(value=element), to_bytes=True)
            for element in self.values()
        ]
        return OrderedDict_PB(id=id_, keys=keys, values=values)

    @staticmethod
    def _proto2object(proto: OrderedDict_PB) -> "OrderedDict":
        """Rebuild an ``OrderedDict`` (including its UID) from its protobuf form."""
        id_: UID = deserialize(blob=proto.id)
        # deserialize from bytes so that we can avoid using StorableObject
        # otherwise we get recursion where the permissions of StorableObject
        # themselves utilise OrderedDict
        values = [
            deserialize(blob=upcast(value=element), from_bytes=True)
            for element in proto.values
        ]
        # deserialize from bytes so that we can avoid using StorableObject
        # otherwise we get recursion where the permissions of StorableObject
        # themselves utilise OrderedDict
        keys = [
            deserialize(blob=upcast(value=element), from_bytes=True)
            for element in proto.keys
        ]
        # Build straight from the pairs (no intermediate dict copy) and hand
        # the restored UID to the constructor.
        return OrderedDict(zip(keys, values), _id=id_)

    @staticmethod
    def get_protobuf_schema() -> GeneratedProtocolMessageType:
        """Return the protobuf message type used for (de)serialization."""
        return OrderedDict_PB

    def upcast(self) -> PyOrderedDict:
        """Convert to a plain ``collections.OrderedDict``.

        Each value is passed through the module-level ``upcast`` helper.
        Keys are kept as they are yielded by ``items()``.
        """
        # recursively upcast; build the stdlib type, not this Syft wrapper
        return PyOrderedDict((k, upcast(v)) for k, v in self.items())