# stdlib
from collections import UserList
from typing import Any
from typing import List as ListType
from typing import Optional
from typing import Union
# third party
from google.protobuf.reflection import GeneratedProtocolMessageType
# syft relative
from ... import deserialize
from ... import serialize
from ...core.common import UID
from ...core.common.serde.serializable import bind_protobuf
from ...proto.lib.python.list_pb2 import List as List_PB
from .iterator import Iterator
from .primitive_factory import PrimitiveFactory
from .primitive_factory import isprimitive
from .primitive_interface import PyPrimitive
from .slice import Slice
from .types import SyPrimitiveRet
from .util import downcast
from .util import upcast
[docs]class ListIterator(Iterator):
pass
[docs]@bind_protobuf
class List(UserList, PyPrimitive):
__slots__ = ["_id", "_index"]
def __init__(self, value: Optional[Any] = None, id: Optional[UID] = None):
if value is None:
value = []
UserList.__init__(self, value)
self._id: UID = id if id else UID()
self._index = 0
@property
def id(self) -> UID:
"""We reveal PyPrimitive.id as a property to discourage users and
developers of Syft from modifying .id attributes after an object
has been initialized.
:return: returns the unique id of the object
:rtype: UID
"""
return self._id
[docs] def upcast(self) -> ListType:
# recursively upcast
new_list = []
# list comprehension doesn't work since it results in a
# [generator()] which is not equal to an empty list
for v in self:
new_list.append(upcast(v))
return new_list
def __gt__(self, other: Any) -> SyPrimitiveRet:
res = super().__gt__(other)
return PrimitiveFactory.generate_primitive(value=res)
def __le__(self, other: Any) -> SyPrimitiveRet:
res = super().__le__(other)
return PrimitiveFactory.generate_primitive(value=res)
def __lt__(self, other: Any) -> SyPrimitiveRet:
res = super().__lt__(other)
return PrimitiveFactory.generate_primitive(value=res)
def __iadd__(self, other: Any) -> SyPrimitiveRet:
res = super().__iadd__(other)
return PrimitiveFactory.generate_primitive(value=res, id=self.id)
def __imul__(self, other: Any) -> SyPrimitiveRet:
res = super().__imul__(other)
return PrimitiveFactory.generate_primitive(value=res, id=self.id)
def __add__(self, other: Any) -> SyPrimitiveRet:
res = super().__add__(other)
return PrimitiveFactory.generate_primitive(value=res)
def __contains__(self, other: Any) -> SyPrimitiveRet:
res = super().__contains__(other)
return PrimitiveFactory.generate_primitive(value=res)
def __delitem__(self, other: Any) -> None:
res = super().__delitem__(other)
return PrimitiveFactory.generate_primitive(value=res)
def __eq__(self, other: Any) -> SyPrimitiveRet:
res = super().__eq__(other)
return PrimitiveFactory.generate_primitive(value=res)
def __ge__(self, other: Any) -> SyPrimitiveRet:
res = super().__ge__(other)
return PrimitiveFactory.generate_primitive(value=res)
def __mul__(self, other: Any) -> SyPrimitiveRet:
res = super().__mul__(other)
return PrimitiveFactory.generate_primitive(value=res)
def __ne__(self, other: Any) -> SyPrimitiveRet:
res = super().__ne__(other)
return PrimitiveFactory.generate_primitive(value=res)
def __sizeof__(self) -> SyPrimitiveRet:
res = super().__sizeof__()
return PrimitiveFactory.generate_primitive(value=res)
[docs] def sort(self, *args: Any, **kwargs: Any) -> None:
res = super().sort(*args, **kwargs)
return PrimitiveFactory.generate_primitive(value=res)
def __len__(self) -> Any:
res = super().__len__()
return PrimitiveFactory.generate_primitive(value=res)
def __getitem__(self, key: Union[int, str, slice, Slice]) -> Any:
if isinstance(key, Slice):
key = key.upcast()
res = super().__getitem__(key) # type: ignore
# we might be holding a primitive value, but generate_primitive
# doesn't handle non primitives so we should check
if isprimitive(value=res):
return PrimitiveFactory.generate_primitive(value=res)
return res
def __iter__(self, max_len: Optional[int] = None) -> ListIterator:
return ListIterator(self, max_len=max_len)
[docs] def copy(self) -> "List":
res = super().copy()
res._id = UID()
return res
[docs] def append(self, item: Any) -> None:
res = super().append(item)
return PrimitiveFactory.generate_primitive(value=res)
[docs] def count(self, other: Any) -> SyPrimitiveRet:
res = super().count(other)
return PrimitiveFactory.generate_primitive(value=res)
def _object2proto(self) -> List_PB:
id_ = serialize(obj=self.id)
downcasted = [downcast(value=element) for element in self.data]
data = [serialize(obj=element, to_bytes=True) for element in downcasted]
return List_PB(id=id_, data=data)
@staticmethod
def _proto2object(proto: List_PB) -> "List":
id_: UID = deserialize(blob=proto.id)
value = []
# list comprehension doesn't work since it results in a
# [generator()] which is not equal to an empty list
for element in proto.data:
value.append(upcast(deserialize(blob=element, from_bytes=True)))
new_list = List(value=value)
new_list._id = id_
return new_list
[docs] @staticmethod
def get_protobuf_schema() -> GeneratedProtocolMessageType:
return List_PB