Source code for syft.lib.python.slice

# stdlib
from typing import Any
from typing import Optional

# third party
from google.protobuf.reflection import GeneratedProtocolMessageType

# syft relative
from ... import deserialize
from ... import serialize
from ...core.common import UID
from ...core.common.serde.serializable import bind_protobuf
from ...proto.lib.python.slice_pb2 import Slice as Slice_PB
from .primitive_factory import PrimitiveFactory
from .primitive_interface import PyPrimitive
from .types import SyPrimitiveRet


# Private sentinel marking "``stop`` was not passed at all". It lets the
# constructor tell ``Slice(5)`` (one-argument form: 5 is the stop) apart from
# an explicit ``Slice(5, None)`` (5 is the start), as builtin ``slice`` does.
_STOP_UNSET: Any = object()


@bind_protobuf
class Slice(PyPrimitive):
    """Syft wrapper around the builtin :class:`slice`.

    The wrapped builtin is stored in ``self.value``. Comparison results are
    re-wrapped through ``PrimitiveFactory.generate_primitive``, and the object
    converts to and from the ``Slice_PB`` protobuf message.
    """

    def __init__(
        self,
        start: Any = None,
        stop: Optional[Any] = _STOP_UNSET,
        step: Optional[Any] = None,
        id: Optional[UID] = None,
    ):
        """Build the wrapped slice, mirroring builtin ``slice`` argument rules.

        ``Slice(x)`` behaves like ``slice(x)``: the lone argument is the stop.
        Passing ``stop`` explicitly, even as ``None``, keeps ``start`` as the
        start, so ``Slice(5, None)`` wraps ``slice(5, None)`` just like the
        builtin.

        :param start: start bound. When ``stop`` is omitted and ``step`` is
            None, it is used as the stop bound instead.
        :param stop: stop bound. Omit it to get the one-argument form.
        :param step: step value.
        :param id: optional UID. A fresh ``UID()`` is generated when falsy.
        """
        if stop is _STOP_UNSET:
            if step is None:
                # one-argument form: builtin slice treats the lone arg as stop
                start, stop = None, start
            else:
                # e.g. Slice(1, step=2) -> slice(1, None, 2)
                stop = None

        self.value = slice(start, stop, step)
        self._id: UID = id if id else UID()

    @property
    def id(self) -> UID:
        """We reveal PyPrimitive.id as a property to discourage users and
        developers of Syft from modifying .id attributes after an object
        has been initialized.

        :return: returns the unique id of the object
        :rtype: UID
        """
        return self._id

    @staticmethod
    def _unwrap_other(other: Any) -> Any:
        """Return the builtin slice inside ``other`` if it is a ``Slice``.

        Builtin slice rich comparisons return ``NotImplemented`` for any
        operand that is not a real ``slice``. Without unwrapping, two equal
        ``Slice`` wrappers would never compare equal.
        """
        return other.value if isinstance(other, Slice) else other

    def __eq__(self, other: Any) -> SyPrimitiveRet:
        res = self.value.__eq__(self._unwrap_other(other))
        return PrimitiveFactory.generate_primitive(value=res)

    def __ge__(self, other: Any) -> SyPrimitiveRet:
        res = self.value.__ge__(self._unwrap_other(other))  # type: ignore
        return PrimitiveFactory.generate_primitive(value=res)

    def __gt__(self, other: Any) -> SyPrimitiveRet:
        res = self.value.__gt__(self._unwrap_other(other))  # type: ignore
        return PrimitiveFactory.generate_primitive(value=res)

    def __le__(self, other: Any) -> SyPrimitiveRet:
        res = self.value.__le__(self._unwrap_other(other))  # type: ignore
        return PrimitiveFactory.generate_primitive(value=res)

    def __lt__(self, other: Any) -> SyPrimitiveRet:
        res = self.value.__lt__(self._unwrap_other(other))  # type: ignore
        return PrimitiveFactory.generate_primitive(value=res)

    def __ne__(self, other: Any) -> SyPrimitiveRet:
        res = self.value.__ne__(self._unwrap_other(other))
        return PrimitiveFactory.generate_primitive(value=res)

    def __str__(self) -> str:
        return self.value.__str__()

    def indices(self, index: int) -> tuple:
        """Return ``slice.indices(index)``, wrapped via ``PrimitiveFactory``.

        :param index: length of the sequence the slice is applied to (this is
            the argument builtin ``slice.indices`` expects).
        :return: the ``(start, stop, step)`` triple, as a Syft primitive.
        """
        res = self.value.indices(index)
        return PrimitiveFactory.generate_primitive(value=res)

    @property
    def start(self) -> Optional[int]:
        """Start bound of the wrapped slice (None when unbounded)."""
        return self.value.start

    @property
    def step(self) -> Optional[int]:
        """Step of the wrapped slice (None when not given)."""
        return self.value.step

    @property
    def stop(self) -> Optional[int]:
        """Stop bound of the wrapped slice (None when unbounded)."""
        return self.value.stop

    def upcast(self) -> slice:
        """Return the underlying builtin :class:`slice` object."""
        return self.value

    def _object2proto(self) -> Slice_PB:
        """Serialize this slice and its id into a ``Slice_PB`` message.

        Each bound is written only when it is not None, and the matching
        ``has_*`` flag records its presence. The test is ``is not None`` rather
        than truthiness, so a bound of ``0`` (e.g. ``slice(0)``, i.e. ``x[:0]``)
        is preserved instead of being dropped.
        """
        slice_pb = Slice_PB()
        if self.start is not None:
            slice_pb.start = self.start
            slice_pb.has_start = True
        if self.stop is not None:
            slice_pb.stop = self.stop
            slice_pb.has_stop = True
        if self.step is not None:
            slice_pb.step = self.step
            slice_pb.has_step = True
        slice_pb.id.CopyFrom(serialize(obj=self._id))
        return slice_pb

    @staticmethod
    def _proto2object(proto: Slice_PB) -> "Slice":
        """Rebuild a ``Slice`` from a ``Slice_PB`` message.

        A bound whose ``has_*`` flag is unset becomes None. ``stop`` is always
        passed explicitly, so the constructor never reinterprets a lone start
        as the stop. For example, ``slice(5, None)`` round-trips to
        ``slice(5, None)`` and not ``slice(None, 5)``.
        """
        id_: UID = deserialize(blob=proto.id)
        start = proto.start if proto.has_start else None
        stop = proto.stop if proto.has_stop else None
        step = proto.step if proto.has_step else None
        return Slice(
            start=start,
            stop=stop,
            step=step,
            id=id_,
        )

    @staticmethod
    def get_protobuf_schema() -> GeneratedProtocolMessageType:
        """Return the protobuf message class used to serialize ``Slice``."""
        return Slice_PB