Source code for syft.lib.python.dict

# stdlib
from collections import UserDict
from collections.abc import ItemsView
from collections.abc import KeysView
from collections.abc import ValuesView
from typing import Any
from typing import Dict as TypeDict
from typing import Iterable
from typing import Optional
from typing import Union

# third party
from google.protobuf.reflection import GeneratedProtocolMessageType

# syft relative
from ... import deserialize
from ... import serialize
from ...core.common import UID
from ...core.common.serde.serializable import bind_protobuf
from ...logger import traceback_and_raise
from ...logger import warning
from ...proto.lib.python.dict_pb2 import Dict as Dict_PB
from .iterator import Iterator
from .primitive_factory import PrimitiveFactory
from .primitive_factory import isprimitive
from .primitive_interface import PyPrimitive
from .types import SyPrimitiveRet
from .util import downcast
from .util import upcast


@bind_protobuf
class Dict(UserDict, PyPrimitive):
    # the incoming types to UserDict __init__ are overloaded and weird
    # see https://github.com/python/cpython/blob/master/Lib/collections/__init__.py
    # this is the version from python 3.7 because we need to support 3.7
    # python 3.8 signature includes a new PEP 570 (args, /, kwargs) syntax:
    # https://www.python.org/dev/peps/pep-0570/

    def __init__(*args: Any, **kwargs: Any) -> None:
        if not args:
            traceback_and_raise(
                TypeError("descriptor '__init__' of 'Dict' object " "needs an argument")
            )
        self, *args = args  # type: ignore
        if len(args) > 1:
            traceback_and_raise(
                TypeError(f"expected at most 1 arguments, got {len(args)}")
            )
        if args:
            args_dict = args[0]
        elif "dict" in kwargs:
            args_dict = kwargs.pop("dict")

            warning(
                "Passing 'dict' as keyword argument is deprecated",
                DeprecationWarning,
                stacklevel=2,
            )
        else:
            args_dict = None

        self.data = {}
        if args_dict is not None:
            self.update(args_dict)
        if kwargs:
            self.update(kwargs)

        # We cant add UID from kwargs or it could easily be overwritten by the dict
        # that is being passed in for __init__
        # If you want to update it use the _id setter after creation.
        self._id = UID()

    @property
    def id(self) -> UID:
        """We reveal PyPrimitive.id as a property to discourage users and
        developers of Syft from modifying .id attributes after an object
        has been initialized.

        :return: returns the unique id of the object
        :rtype: UID
        """
        return self._id

[docs] def upcast(self) -> TypeDict: # recursively upcast return {k: upcast(v) for k, v in self.items()}
def __contains__(self, other: Any) -> SyPrimitiveRet: res = super().__contains__(other) return PrimitiveFactory.generate_primitive(value=res) def __eq__(self, other: Any) -> SyPrimitiveRet: res = super().__eq__(other) return PrimitiveFactory.generate_primitive(value=res) def __format__(self, format_spec: str) -> str: # python complains if the return value is not str res = super().__format__(format_spec) return str(res) def __ge__(self, other: Any) -> SyPrimitiveRet: res = super().__ge__(other) # type: ignore return PrimitiveFactory.generate_primitive(value=res) def __getitem__(self, key: Any) -> Union[SyPrimitiveRet, Any]: res = super().__getitem__(key) if isprimitive(value=res): return PrimitiveFactory.generate_primitive(value=res) else: # we can have torch.Tensor and other types return res def __gt__(self, other: Any) -> SyPrimitiveRet: res = super().__gt__(other) # type: ignore return PrimitiveFactory.generate_primitive(value=res) def __hash__(self) -> SyPrimitiveRet: res = super().__hash__() return PrimitiveFactory.generate_primitive(value=res) def __iter__(self, max_len: Optional[int] = None) -> Iterator: return Iterator(super().__iter__(), max_len=max_len) def __le__(self, other: Any) -> SyPrimitiveRet: res = super().__le__(other) # type: ignore return PrimitiveFactory.generate_primitive(value=res) def __len__(self) -> SyPrimitiveRet: res = super().__len__() return PrimitiveFactory.generate_primitive(value=res) def __lt__(self, other: Any) -> SyPrimitiveRet: res = super().__lt__(other) # type: ignore return PrimitiveFactory.generate_primitive(value=res) def __ne__(self, other: Any) -> SyPrimitiveRet: res = super().__ne__(other) return PrimitiveFactory.generate_primitive(value=res) def __sizeof__(self) -> SyPrimitiveRet: res = super().__sizeof__() return PrimitiveFactory.generate_primitive(value=res)
[docs] def copy(self) -> SyPrimitiveRet: res = super().copy() return PrimitiveFactory.generate_primitive(value=res)
[docs] @classmethod def fromkeys( cls, iterable: Iterable, value: Optional[Any] = None ) -> SyPrimitiveRet: res = super().fromkeys(iterable, value) return PrimitiveFactory.generate_primitive(value=res)
[docs] def dict_get(self, key: Any, default: Any = None) -> Any: res = super().get(key, default) if isprimitive(value=res): return PrimitiveFactory.generate_primitive(value=res) else: # we can have torch.Tensor and other types return res
[docs] def items(self, max_len: Optional[int] = None) -> Iterator: # type: ignore return Iterator(ItemsView(self), max_len=max_len)
[docs] def keys(self, max_len: Optional[int] = None) -> Iterator: # type: ignore return Iterator(KeysView(self), max_len=max_len)
[docs] def values(self, *args: Any, max_len: Optional[int] = None) -> Iterator: # type: ignore # this is what the super type does and there is a test in dict_test.py # test_values which checks for this so we could disable the test or # keep this workaround if len(args) > 0: traceback_and_raise( TypeError("values() takes 1 positional argument but 2 were given") ) return Iterator(ValuesView(self), max_len=max_len)
[docs] def pop(self, key: Any, *args: Any) -> SyPrimitiveRet: res = super().pop(key, *args) return PrimitiveFactory.generate_primitive(value=res)
[docs] def popitem(self) -> SyPrimitiveRet: res = self.data.popitem() return PrimitiveFactory.generate_primitive(value=res)
[docs] def setdefault(self, key: Any, default: Any = None) -> SyPrimitiveRet: res = PrimitiveFactory.generate_primitive(value=default) res = super().setdefault(key, res) return res
[docs] def clear(self) -> None: # we get the None return and create a SyNone # this is to make sure someone doesn't rewrite the method to return nothing return PrimitiveFactory.generate_primitive(value=super().clear())
def _object2proto(self) -> Dict_PB: id_ = serialize(obj=self.id) keys = [ serialize(obj=downcast(value=element), to_bytes=True) for element in self.data.keys() ] values = [ serialize(obj=downcast(value=element), to_bytes=True) for element in self.data.values() ] return Dict_PB(id=id_, keys=keys, values=values) @staticmethod def _proto2object(proto: Dict_PB) -> "Dict": id_: UID = deserialize(blob=proto.id) values = [ upcast(value=deserialize(blob=element, from_bytes=True)) for element in proto.values ] keys = [ upcast(value=deserialize(blob=element, from_bytes=True)) for element in proto.keys ] new_dict = Dict(dict(zip(keys, values))) new_dict._id = id_ return new_dict
[docs] @staticmethod def get_protobuf_schema() -> GeneratedProtocolMessageType: return Dict_PB