# stdlib
from collections import UserString
from typing import Any
from typing import Mapping
from typing import Optional
from typing import Union
# third party
from google.protobuf.reflection import GeneratedProtocolMessageType
# syft relative
from ... import deserialize
from ... import serialize
from ...core.common import UID
from ...core.common.serde.serializable import bind_protobuf
from ...proto.lib.python.string_pb2 import String as String_PB
from .int import Int
from .primitive_factory import PrimitiveFactory
from .primitive_interface import PyPrimitive
from .slice import Slice
from .types import SyPrimitiveRet
@bind_protobuf
class String(UserString, PyPrimitive):
    """Syft wrapper around ``collections.UserString``.

    Nearly every method delegates to the inherited ``UserString``
    implementation and passes the result through
    ``PrimitiveFactory.generate_primitive`` (which presumably maps the
    builtin result onto the matching Syft primitive). Where the underlying
    ``str`` method needs plain strings, ``String`` / ``UserString``
    arguments are converted with ``str()`` first. Each instance carries a
    ``UID`` that is serialized alongside the text in ``String_PB``.
    """

    def __init__(self, value: Any = None, id: Optional[UID] = None):
        """Create a String.

        :param value: initial contents; ``None`` becomes the empty string.
        :param id: identifier to attach; a fresh ``UID()`` is generated
            when it is falsy (e.g. ``None``).
        """
        if value is None:
            value = ""

        UserString.__init__(self, value)

        self._id: UID = id if id else UID()

    def upcast(self) -> str:
        """Return the contents as a builtin ``str``."""
        return str(self)

    def __add__(self, other: Any) -> SyPrimitiveRet:
        res = super().__add__(other)
        return PrimitiveFactory.generate_primitive(value=res)

    def __eq__(self, other: Any) -> SyPrimitiveRet:
        res = super().__eq__(other)
        return PrimitiveFactory.generate_primitive(value=res)

    def __float__(self) -> SyPrimitiveRet:
        res = super().__float__()
        return PrimitiveFactory.generate_primitive(value=res)

    def __ge__(self, other: Any) -> SyPrimitiveRet:
        res = super().__ge__(other)
        return PrimitiveFactory.generate_primitive(value=res)

    def __getitem__(self, key: Union[int, slice, Slice]) -> Any:
        """Index or slice the string; Syft ``Slice`` keys are upcast first."""
        if isinstance(key, Slice):
            key = key.upcast()
        res = super().__getitem__(key)
        return PrimitiveFactory.generate_primitive(value=res)

    def __gt__(self, other: Any) -> SyPrimitiveRet:
        res = super().__gt__(other)
        return PrimitiveFactory.generate_primitive(value=res)

    def __hash__(self) -> SyPrimitiveRet:
        res = super().__hash__()
        return PrimitiveFactory.generate_primitive(value=res)

    def __int__(self) -> SyPrimitiveRet:
        res = super().__int__()
        return PrimitiveFactory.generate_primitive(value=res)

    def __iter__(self) -> SyPrimitiveRet:
        # TODO fix this: the inherited iterator object itself is handed to
        # PrimitiveFactory.generate_primitive, not the individual items.
        res = super().__iter__()
        return PrimitiveFactory.generate_primitive(value=res)

    def __le__(self, other: Any) -> SyPrimitiveRet:
        res = super().__le__(other)
        return PrimitiveFactory.generate_primitive(value=res)

    def __len__(self) -> SyPrimitiveRet:
        res = super().__len__()
        return PrimitiveFactory.generate_primitive(value=res)

    def __lt__(self, other: Any) -> SyPrimitiveRet:
        res = super().__lt__(other)
        return PrimitiveFactory.generate_primitive(value=res)

    def __mod__(self, *args: Any) -> SyPrimitiveRet:
        """%-format the string; ``String`` arguments are converted to ``str``."""
        res = super().__mod__(
            *[str(arg) if isinstance(arg, String) else arg for arg in args]
        )
        return PrimitiveFactory.generate_primitive(value=res)

    def __mul__(self, other: Any) -> SyPrimitiveRet:
        res = super().__mul__(other)
        return PrimitiveFactory.generate_primitive(value=res)

    def __ne__(self, other: Any) -> SyPrimitiveRet:
        res = super().__ne__(other)
        return PrimitiveFactory.generate_primitive(value=res)

    def __reversed__(self) -> Any:
        """Delegate to the inherited ``__reversed__``.

        Unlike most methods here, the result is returned as-is rather than
        passed through ``PrimitiveFactory``.
        """
        # NOTE(review): UserString does not appear to define __reversed__
        # itself, so this likely resolves to the collections.abc.Sequence
        # mixin, which returns a generator rather than a builtin
        # ``reversed`` object -- confirm before relying on the type.
        return super().__reversed__()

    def __sizeof__(self) -> SyPrimitiveRet:
        res = super().__sizeof__()
        return PrimitiveFactory.generate_primitive(value=res)

    def __str__(self) -> str:
        return super().__str__()

    def capitalize(self) -> SyPrimitiveRet:
        res = super().capitalize()
        return PrimitiveFactory.generate_primitive(value=res)

    def casefold(self) -> SyPrimitiveRet:
        res = super().casefold()
        return PrimitiveFactory.generate_primitive(value=res)

    def center(self, width: Union[int, Int], *args: Any) -> SyPrimitiveRet:
        """Center in ``width``; an optional fill char may be a ``String``."""
        if args:
            _args_0 = str(args[0]) if isinstance(args[0], String) else args[0]
            res = super().center(width, _args_0, *args[1:])
        else:
            res = super().center(width)
        return PrimitiveFactory.generate_primitive(value=res)

    def count(
        self, sub: Any, start: Optional[int] = None, end: Optional[int] = None
    ) -> SyPrimitiveRet:
        res = super().count(sub, start, end)  # type: ignore
        return PrimitiveFactory.generate_primitive(value=res)

    def encode(
        self, encoding: Optional[str] = None, errors: Optional[str] = None
    ) -> SyPrimitiveRet:
        res = super().encode(encoding, errors)
        return PrimitiveFactory.generate_primitive(value=res)

    def endswith(
        self,
        suffix: Union[str, UserString, tuple],
        start: Optional[int] = None,
        end: Optional[int] = None,
    ) -> SyPrimitiveRet:
        """Test for ``suffix`` (a string or tuple of strings).

        Any ``UserString`` -- on its own or inside the tuple -- is converted
        to ``str`` so the underlying ``str.endswith`` accepts it.
        """
        suffix = str(suffix) if isinstance(suffix, UserString) else suffix
        _suffix = (
            tuple(
                str(elem) if isinstance(elem, UserString) else elem for elem in suffix
            )
            if isinstance(suffix, tuple)
            else suffix
        )
        res = super().endswith(_suffix, start, end)  # type: ignore
        return PrimitiveFactory.generate_primitive(value=res)

    def expandtabs(self, tabsize: int = 8) -> SyPrimitiveRet:
        res = super().expandtabs(tabsize)
        return PrimitiveFactory.generate_primitive(value=res)

    def find(
        self, sub: Any, start: Optional[int] = 0, end: Optional[int] = None
    ) -> SyPrimitiveRet:
        if end is None:
            end = super().__len__()
        res = super().find(sub, start, end)  # type: ignore
        return PrimitiveFactory.generate_primitive(value=res)

    def index(
        self,
        sub: Union[str, "String"],
        start: Optional[int] = 0,
        end: Optional[int] = None,
    ) -> SyPrimitiveRet:
        if end is None:
            end = super().__len__()
        res = super().index(str(sub), start, end)  # type: ignore
        return PrimitiveFactory.generate_primitive(value=res)

    def isalnum(self) -> SyPrimitiveRet:
        res = super().isalnum()
        return PrimitiveFactory.generate_primitive(value=res)

    def isascii(self) -> SyPrimitiveRet:
        res = super().isascii()  # type: ignore
        return PrimitiveFactory.generate_primitive(value=res)

    def isalpha(self) -> SyPrimitiveRet:
        res = super().isalpha()
        return PrimitiveFactory.generate_primitive(value=res)

    def isdecimal(self) -> SyPrimitiveRet:
        res = super().isdecimal()
        return PrimitiveFactory.generate_primitive(value=res)

    def isdigit(self) -> SyPrimitiveRet:
        res = super().isdigit()
        return PrimitiveFactory.generate_primitive(value=res)

    def isidentifier(self) -> SyPrimitiveRet:
        res = super().isidentifier()
        return PrimitiveFactory.generate_primitive(value=res)

    def islower(self) -> SyPrimitiveRet:
        res = super().islower()
        return PrimitiveFactory.generate_primitive(value=res)

    def isnumeric(self) -> SyPrimitiveRet:
        res = super().isnumeric()
        return PrimitiveFactory.generate_primitive(value=res)

    def isprintable(self) -> SyPrimitiveRet:
        res = super().isprintable()
        return PrimitiveFactory.generate_primitive(value=res)

    def isspace(self) -> SyPrimitiveRet:
        res = super().isspace()
        return PrimitiveFactory.generate_primitive(value=res)

    def istitle(self) -> SyPrimitiveRet:
        res = super().istitle()
        return PrimitiveFactory.generate_primitive(value=res)

    def isupper(self) -> SyPrimitiveRet:
        res = super().isupper()
        return PrimitiveFactory.generate_primitive(value=res)

    def join(self, seq: Any) -> SyPrimitiveRet:
        """Join ``seq`` with this string; ``String`` items become ``str``."""
        res = super().join(
            [str(elem) if isinstance(elem, String) else elem for elem in seq]
        )
        return PrimitiveFactory.generate_primitive(value=res)

    def ljust(self, width: Union[int, Int], *args: Any) -> SyPrimitiveRet:
        """Left-justify in ``width``; an optional fill char may be a ``String``."""
        if args:
            _args_0 = str(args[0]) if isinstance(args[0], String) else args[0]
            res = super().ljust(width, _args_0, *args[1:])
        else:
            res = super().ljust(width)
        return PrimitiveFactory.generate_primitive(value=res)

    def lower(self) -> SyPrimitiveRet:
        res = super().lower()
        return PrimitiveFactory.generate_primitive(value=res)

    def lstrip(self, chars: Optional[Union[str, "String"]] = None) -> SyPrimitiveRet:
        chars = str(chars) if isinstance(chars, String) else chars
        res = super().lstrip(chars)
        return PrimitiveFactory.generate_primitive(value=res)

    def partition(self, sep: Optional[Union[str, "String"]] = " ") -> SyPrimitiveRet:
        sep = str(sep) if isinstance(sep, String) else sep
        res = super().partition(sep)  # type: ignore
        return PrimitiveFactory.generate_primitive(value=res)

    def replace(
        self,
        oldvalue: Union[str, UserString],
        newvalue: Union[str, UserString],
        count: Optional[int] = -1,
    ) -> SyPrimitiveRet:
        """Replace occurrences of ``oldvalue`` with ``newvalue``.

        :param count: maximum number of replacements; ``-1`` (the default)
            or ``None`` replaces every occurrence.
        """
        # str.replace rejects None, so map the Optional default onto -1.
        _count = -1 if count is None else count
        res = super().replace(str(oldvalue), str(newvalue), _count)  # type: ignore
        return PrimitiveFactory.generate_primitive(value=res)

    def rfind(
        self,
        sub: Union[str, UserString],
        start: Optional[int] = 0,
        end: Optional[int] = None,
    ) -> SyPrimitiveRet:
        sub = str(sub) if isinstance(sub, UserString) else sub
        res = super().rfind(sub, start, end)  # type: ignore
        return PrimitiveFactory.generate_primitive(value=res)

    def rindex(
        self,
        sub: Union[str, UserString],
        start: Optional[int] = 0,
        end: Optional[int] = None,
    ) -> SyPrimitiveRet:
        """Like ``rfind`` but raises ``ValueError`` when ``sub`` is absent."""
        # Unwrap any UserString (not only String) so the underlying
        # str.rindex receives a plain str, matching rfind above.
        sub = str(sub) if isinstance(sub, UserString) else sub
        res = super().rindex(sub, start, end)  # type: ignore
        return PrimitiveFactory.generate_primitive(value=res)

    def rjust(self, width: Union[int, Int], *args: Any) -> SyPrimitiveRet:
        """Right-justify in ``width``; an optional fill char may be a ``String``."""
        if args:
            _args_0 = str(args[0]) if isinstance(args[0], String) else args[0]
            res = super().rjust(width, _args_0, *args[1:])
        else:
            res = super().rjust(width)
        return PrimitiveFactory.generate_primitive(value=res)

    def rpartition(self, sep: Optional[Union[str, "String"]] = " ") -> SyPrimitiveRet:
        sep = str(sep) if isinstance(sep, String) else sep
        res = super().rpartition(sep)  # type: ignore
        return PrimitiveFactory.generate_primitive(value=res)

    def rsplit(
        self, sep: Optional[Union[str, "String"]] = None, maxsplit: int = -1
    ) -> SyPrimitiveRet:
        sep = str(sep) if isinstance(sep, String) else sep
        res = super().rsplit(sep=sep, maxsplit=maxsplit)
        return PrimitiveFactory.generate_primitive(value=res)

    def rstrip(self, chars: Optional[Union[str, "String"]] = None) -> SyPrimitiveRet:
        chars = str(chars) if isinstance(chars, String) else chars
        res = super().rstrip(chars)
        return PrimitiveFactory.generate_primitive(value=res)

    def split(
        self, sep: Optional[Union[str, "String"]] = None, maxsplit: int = -1
    ) -> SyPrimitiveRet:
        sep = str(sep) if isinstance(sep, String) else sep
        res = super().split(sep=sep, maxsplit=maxsplit)
        return PrimitiveFactory.generate_primitive(value=res)

    def splitlines(self, keepends: bool = False) -> SyPrimitiveRet:
        res = super().splitlines(keepends)
        return PrimitiveFactory.generate_primitive(value=res)

    def startswith(
        self,
        suffix: Union[str, UserString, tuple],
        start: int = 0,
        end: Optional[int] = None,
    ) -> SyPrimitiveRet:
        """Test whether the string starts with the given prefix.

        :param suffix: the *prefix* to test (a string or tuple of strings);
            the parameter name is kept for keyword-argument compatibility.
            ``UserString`` values, alone or in the tuple, become ``str``.
        :param start: index to start comparing at.
        :param end: index to stop comparing at; ``None`` means the end of
            the string. An explicit ``0`` is honoured as an empty range.
        """
        suffix = str(suffix) if isinstance(suffix, UserString) else suffix
        suffix = (
            tuple(
                str(elem) if isinstance(elem, UserString) else elem for elem in suffix
            )
            if isinstance(suffix, tuple)
            else suffix
        )
        # Compare with None instead of truthiness: end=0 must not silently
        # widen to the whole string.
        if end is None:
            end = len(self.data)
        res = super().startswith(suffix, start, end)
        return PrimitiveFactory.generate_primitive(value=res)

    def strip(self, chars: Optional[Union[str, "String"]] = None) -> SyPrimitiveRet:
        chars = str(chars) if isinstance(chars, String) else chars
        res = super().strip(chars)
        return PrimitiveFactory.generate_primitive(value=res)

    def swapcase(self) -> SyPrimitiveRet:
        res = super().swapcase()
        return PrimitiveFactory.generate_primitive(value=res)

    def title(self) -> SyPrimitiveRet:
        res = super().title()
        return PrimitiveFactory.generate_primitive(value=res)

    def translate(self, *args: Any) -> SyPrimitiveRet:
        res = super().translate(*args)
        return PrimitiveFactory.generate_primitive(value=res)

    def upper(self) -> SyPrimitiveRet:
        res = super().upper()
        return PrimitiveFactory.generate_primitive(value=res)

    def zfill(self, width: Union[int, Int]) -> SyPrimitiveRet:
        res = super().zfill(width)
        return PrimitiveFactory.generate_primitive(value=res)

    def __contains__(self, val: object) -> SyPrimitiveRet:
        res = super().__contains__(val)
        return PrimitiveFactory.generate_primitive(value=res)

    @property
    def id(self) -> UID:
        """We reveal PyPrimitive.id as a property to discourage users and
        developers of Syft from modifying .id attributes after an object
        has been initialized.

        :return: returns the unique id of the object
        :rtype: UID
        """
        return self._id

    def _object2proto(self) -> String_PB:
        """Serialize the text and the serialized ``id`` into ``String_PB``."""
        return String_PB(data=self.data, id=serialize(obj=self.id))

    @staticmethod
    def _proto2object(proto: String_PB) -> "String":
        """Rebuild a ``String`` from ``String_PB``, restoring its ``id``."""
        str_id: UID = deserialize(blob=proto.id)
        return String(value=proto.data, id=str_id)

    @staticmethod
    def get_protobuf_schema() -> GeneratedProtocolMessageType:
        """Return the protobuf message class used to serialize ``String``."""
        return String_PB

    # fixes __rmod__ in python <= 3.7
    # https://github.com/python/cpython/commit/7abf8c60819d5749e6225b371df51a9c5f1ea8e9
    def __rmod__(self, template: Union[PyPrimitive, str]) -> PyPrimitive:
        return self.__class__(str(template) % self)