"""This module contains `Class` attribute,an AST node representing a class."""
# stdlib
from enum import Enum
from enum import EnumMeta
import inspect
from types import ModuleType
from typing import Any
from typing import Callable as CallableT
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import warnings
# syft relative
from .. import ast
from .. import lib
from ..ast.callable import Callable
from ..core.common.group import VERIFYALL
from ..core.common.uid import UID
from ..core.node.common.action.get_or_set_property_action import GetOrSetPropertyAction
from ..core.node.common.action.get_or_set_property_action import PropertyActions
from ..core.node.common.action.run_class_method_action import RunClassMethodAction
from ..core.node.common.action.save_object_action import SaveObjectAction
from ..core.node.common.service.resolve_pointer_type_service import (
ResolvePointerTypeMessage,
)
from ..core.pointer.pointer import Pointer
from ..core.store.storeable_object import StorableObject
from ..logger import critical
from ..logger import traceback_and_raise
from ..logger import warning
from ..util import aggressive_set_attr
from ..util import inherit_tags
def _resolve_pointer_type(self: Pointer) -> Pointer:
"""Resolve pointer of the object.
Creates a request on a pointer to validate and regenerate the current pointer type. This method
is useful when deadling with AnyPointer or Union<types>Pointers, to retrieve the real pointer.
The existing pointer will be deleted and a new one will be generated. The remote data won't
be touched.
Args:
self: The pointer which will be validated.
Returns:
The new pointer, validated from the remote object.
"""
# id_at_location has to be preserved
id_at_location = getattr(self, "id_at_location", None)
if None:
traceback_and_raise(
ValueError("Can't resolve a pointer that has no underlying object.")
)
cmd = ResolvePointerTypeMessage(
id_at_location=id_at_location,
address=self.client.address,
reply_to=self.client.address,
)
# the path to the underlying type. It has to live in the AST
real_type_path = self.client.send_immediate_msg_with_reply(msg=cmd).type_path
new_pointer = self.client.lib_ast.query(real_type_path).pointer_type(
client=self.client, id_at_location=id_at_location
)
# we disable the garbage collection message and then we delete the existing message.
self.gc_enabled = False
del self
return new_pointer
[docs]def get_run_class_method(attr_path_and_name: str) -> CallableT:
"""Create a function for class method in `attr_path_and_name` for remote execution.
Args:
attr_path_and_name: The path of the class method.
Returns:
Function for the class method.
Note:
It might seem hugely un-necessary to have these methods nested in this way.
However, it has to do with ensuring that the scope of `attr_path_and_name` is local
and not global.
If we do not put a `get_run_class_method` around `run_class_method` then
each `run_class_method` will end up referencing the same `attr_path_and_name` variable
and all methods will actually end up calling the same method.
If, instead, we return the function object itself then it includes
the current `attr_path_and_name` as an internal variable and when we call `get_run_class_method`
multiple times it returns genuinely different methods each time with a different
internal `attr_path_and_name` variable.
"""
def run_class_method(
__self: Any,
*args: Tuple[Any, ...],
**kwargs: Any,
) -> object:
"""Run remote class method and get pointer to returned object.
Args:
*args: Args list of class method.
**kwargs: Keyword args of class method.
Returns:
Pointer to object returned by class method.
"""
# we want to get the return type which matches the attr_path_and_name
# so we ask lib_ast for the return type name that matches out
# attr_path_and_name and then use that to get the actual pointer klass
# then set the result to that pointer klass
return_type_name = __self.client.lib_ast.query(
attr_path_and_name
).return_type_name
resolved_pointer_type = __self.client.lib_ast.query(return_type_name)
result = resolved_pointer_type.pointer_type(client=__self.client)
# QUESTION can the id_at_location be None?
result_id_at_location = getattr(result, "id_at_location", None)
if result_id_at_location is not None:
# first downcast anything primitive which is not already PyPrimitive
(
downcast_args,
downcast_kwargs,
) = lib.python.util.downcast_args_and_kwargs(args=args, kwargs=kwargs)
# then we convert anything which isnt a pointer into a pointer
pointer_args, pointer_kwargs = pointerize_args_and_kwargs(
args=downcast_args, kwargs=downcast_kwargs, client=__self.client
)
cmd = RunClassMethodAction(
path=attr_path_and_name,
_self=__self,
args=pointer_args,
kwargs=pointer_kwargs,
id_at_location=result_id_at_location,
address=__self.client.address,
)
__self.client.send_immediate_msg_without_reply(msg=cmd)
inherit_tags(
attr_path_and_name=attr_path_and_name,
result=result,
self_obj=__self,
args=args,
kwargs=kwargs,
)
return result
return run_class_method
[docs]def generate_class_property_function(
attr_path_and_name: str, action: PropertyActions, map_to_dyn: bool
) -> CallableT:
"""Returns a function that handles action on property.
Args:
attr_path_and_name: The path of the property in AST.
action: action to perform on property (GET | SET | DEL).
Returns:
Function to handle action on property.
"""
def class_property_function(__self: Any, *args: Any, **kwargs: Any) -> object:
"""Handles remote action on property and returns pointer.
Args:
*args: Argument list.
**kwargs: Keyword arguments.
Returns:
Pointer to the object returned.
"""
# we want to get the return type which matches the attr_path_and_name
# so we ask lib_ast for the return type name that matches out
# attr_path_and_name and then use that to get the actual pointer klass
# then set the result to that pointer klass
return_type_name = __self.client.lib_ast.query(
attr_path_and_name
).return_type_name
resolved_pointer_type = __self.client.lib_ast.query(return_type_name)
result = resolved_pointer_type.pointer_type(client=__self.client)
# QUESTION can the id_at_location be None?
result_id_at_location = getattr(result, "id_at_location", None)
if result_id_at_location is not None:
# first downcast anything primitive which is not already PyPrimitive
(
downcast_args,
downcast_kwargs,
) = lib.python.util.downcast_args_and_kwargs(args=args, kwargs=kwargs)
# then we convert anything which isnt a pointer into a pointer
pointer_args, pointer_kwargs = pointerize_args_and_kwargs(
args=downcast_args, kwargs=downcast_kwargs, client=__self.client
)
cmd = GetOrSetPropertyAction(
path=attr_path_and_name,
id_at_location=result_id_at_location,
address=__self.client.address,
_self=__self,
args=pointer_args,
kwargs=pointer_kwargs,
action=action,
map_to_dyn=map_to_dyn,
)
__self.client.send_immediate_msg_without_reply(msg=cmd)
if action == PropertyActions.GET:
inherit_tags(
attr_path_and_name=attr_path_and_name,
result=result,
self_obj=__self,
args=args,
kwargs=kwargs,
)
return result
return class_property_function
def _get_request_config(self: Any) -> Dict[str, Any]:
"""Get config for request.
Args:
self: object.
Returns:
Config for request.
"""
return {
"request_block": True,
"timeout_secs": 25,
"delete_obj": False,
}
def _set_request_config(self: Any, request_config: Dict[str, Any]) -> None:
"""Set config for request.
Args:
self: object.
request_config: new config.
"""
setattr(self, "get_request_config", lambda: request_config)
[docs]def wrap_iterator(attrs: Dict[str, Union[str, CallableT, property]]) -> None:
"""Add syft Iterator to `attrs['__iter__']`.
Args:
attrs: Dict of `Attribute`s of node.
Raises:
AttributeError: Base `__iter__` is not callable.
"""
def wrap_iter(iter_func: CallableT) -> CallableT:
"""Create syft iterator for `iter_func`.
Args:
iter_func: Base Iterator.
Returns:
Wrapped Iterator.
"""
def __iter__(self: Any) -> Iterable:
"""Create Syft Iterator for `iter_func`.
Args:
self: object to add iterator to.
Raises:
ValueError: Falied ot access __len__.
Returns:
Iterable: syft Iterator.
"""
# syft absolute
from syft.lib.python.iterator import Iterator
if not hasattr(self, "__len__"):
traceback_and_raise(
ValueError(
"Can't build a remote iterator on an object with no __len__."
)
)
try:
data_len = self.__len__()
except Exception:
traceback_and_raise(
ValueError("Request to access data length rejected.")
)
return Iterator(_ref=iter_func(self), max_len=data_len)
return __iter__
attr_name = "__iter__"
iter_target = attrs[attr_name]
# skip if __iter__ has already been wrapped
qual_name = getattr(iter_target, "__qualname__", None)
if qual_name and "wrap_iter" in qual_name:
return
if not callable(iter_target):
traceback_and_raise(AttributeError("Can't wrap a non callable iter attribute"))
else:
iter_func: CallableT = iter_target
attrs[attr_name] = wrap_iter(iter_func)
[docs]def wrap_len(attrs: Dict[str, Union[str, CallableT, property]]) -> None:
"""Add method to access pointer len to `attr[__len__]`.
Args:
attrs: Dict of `Attribute`s of node.
Raises:
AttributeError: Base `__len__` is not callable.
"""
def wrap_len(len_func: CallableT) -> CallableT:
"""Add wrapper function for `len_func`.
Args:
len_func: Base len function.
Returns:
Wrapped len function.
"""
def __len__(self: Any) -> int:
"""Access len of pointer obj.
Args:
self: object to add iterator to.
Returns:
int: length of object.
Raises:
ValueError: Request to access data length rejected.
"""
data_len_ptr = len_func(self)
try:
data_len = data_len_ptr.get(**self.get_request_config())
if data_len is None:
raise Exception
return data_len
except Exception:
traceback_and_raise(
ValueError("Request to access data length rejected.")
)
return __len__
attr_name = "__len__"
len_target = attrs[attr_name]
if not callable(len_target):
traceback_and_raise(
AttributeError("Can't wrap a non callable __len__ attribute")
)
else:
len_func: CallableT = len_target
attrs["len"] = len_func
attrs[attr_name] = wrap_len(len_func)
[docs]def attach_description(obj: object, description: str) -> None:
"""Add description to the object.
Args:
obj: Object to add description to.
description: Description.
Raises:
AttributeError: Cannot add description to object.
"""
try:
obj.description = description # type: ignore
except AttributeError:
warning(f"can't attach new attribute `description` to {type(obj)} object.")
[docs]class Class(Callable):
"""A Class attribute represents a class."""
[docs] def __init__(
self,
path_and_name: str,
parent: ast.attribute.Attribute,
object_ref: Union[Callable, CallableT],
return_type_name: Optional[str],
client: Optional[Any],
) -> None:
"""Base constructor for Class Attribute.
Args:
path_and_name: The path for the current node, e.g. `syft.lib.python.List`.
parent: The parent node is needed when solving `EnumAttributes`.
object_ref: The actual python object for which the computation is being made.
return_type_name: The return type name of given action as a string with its full path.
client: The client for which all computation is being executed.
"""
super().__init__(
path_and_name=path_and_name,
object_ref=object_ref,
return_type_name=return_type_name,
client=client,
parent=parent,
)
if self.path_and_name is not None:
self.pointer_name = self.path_and_name.split(".")[-1] + "Pointer"
@property
def pointer_type(self) -> Union[Callable, CallableT]:
"""Get pointer type of Class Attribute.
Returns:
`pointer_type` of the object.
"""
return getattr(self, self.pointer_name)
[docs] def create_pointer_class(self) -> None:
"""Create pointer type for object."""
attrs: Dict[str, Union[str, CallableT, property]] = {}
for attr_name, attr in self.attrs.items():
attr_path_and_name = getattr(attr, "path_and_name", None)
# attr_path_and_name None
if isinstance(attr, ast.callable.Callable):
attrs[attr_name] = get_run_class_method(attr_path_and_name)
elif isinstance(attr, ast.property.Property):
prop = property(
generate_class_property_function(
attr_path_and_name, PropertyActions.GET, map_to_dyn=False
)
)
prop = prop.setter(
generate_class_property_function(
attr_path_and_name, PropertyActions.SET, map_to_dyn=False
)
)
prop = prop.deleter(
generate_class_property_function(
attr_path_and_name, PropertyActions.DEL, map_to_dyn=False
)
)
attrs[attr_name] = prop
elif isinstance(attr, ast.dynamic_object.DynamicObject):
prop = property(
generate_class_property_function(
attr_path_and_name, PropertyActions.GET, map_to_dyn=True
)
)
prop = prop.setter(
generate_class_property_function(
attr_path_and_name, PropertyActions.SET, map_to_dyn=True
)
)
prop = prop.deleter(
generate_class_property_function(
attr_path_and_name, PropertyActions.DEL, map_to_dyn=True
)
)
attrs[attr_name] = prop
if attr_name == "__len__":
wrap_len(attrs)
if getattr(attr, "return_type_name", None) == "syft.lib.python.Iterator":
wrap_iterator(attrs)
attrs["get_request_config"] = _get_request_config
attrs["set_request_config"] = _set_request_config
attrs["resolve_pointer_type"] = _resolve_pointer_type
fqn = "Pointer"
if self.path_and_name is not None:
fqn = self.path_and_name + fqn
new_class_name = f"syft.proxy.{fqn}"
parts = new_class_name.split(".")
name = parts.pop(-1)
attrs["__name__"] = name
attrs["__module__"] = ".".join(parts)
klass_pointer = type(self.pointer_name, (Pointer,), attrs)
setattr(klass_pointer, "path_and_name", self.path_and_name)
setattr(self, self.pointer_name, klass_pointer)
[docs] def store_init_args(outer_self: Any) -> None:
"""
Stores args and kwargs of outer_self init by wrapping the init method.
"""
def init_wrapper(self: Any, *args: List[Any], **kwargs: Dict[Any, Any]) -> None:
outer_self.object_ref._wrapped_init(self, *args, **kwargs)
self._init_args = args
self._init_kwargs = kwargs
# If _wrapped_init already exists, create_init_method is already called once
# and does not need to wrap __init__ again.
if not hasattr(outer_self.object_ref, "_wrapped_init"):
outer_self.object_ref._wrapped_init = outer_self.object_ref.__init__
outer_self.object_ref.__init__ = init_wrapper
[docs] def create_send_method(outer_self: Any) -> None:
"""Add `send` method to `outer_self.object_ref`."""
def send(
self: Any,
client: Any,
pointable: bool = True,
description: str = "",
tags: Optional[List[str]] = None,
searchable: Optional[bool] = None,
) -> Pointer:
"""Send obj to client and return pointer to the object.
Args:
self: Object to be sent.
client: Client to send object to.
pointable:
description: Description for the object to send.
tags: Tags for the object to send.
Returns:
Pointer to sent object.
Note:
`searchable` is deprecated please use `pointable` in the future.
"""
if searchable is not None:
msg = "`searchable` is deprecated please use `pointable` in future"
warning(msg, print=True)
warnings.warn(
msg,
DeprecationWarning,
)
pointable = searchable
if not hasattr(self, "id"):
try:
self.id = UID()
except AttributeError:
pass
# if `tags` is passed in, use it; else, use obj_tags
obj_tags = getattr(self, "tags", [])
tags = tags if tags else []
tags = tags if tags else obj_tags
# if `description` is passed in, use it; else, use obj_description
obj_description = getattr(self, "description", "")
description = description if description else obj_description
# TODO: Allow Classes to opt out in the AST like Pandas where the properties
# would break their dict attr usage
# Issue: https://github.com/OpenMined/PySyft/issues/5322
if outer_self.pointer_name not in {"DataFramePointer", "SeriesPointer"}:
attach_tags(self, tags)
attach_description(self, description)
id_at_location = UID()
# Step 1: create pointer which will point to result
ptr = getattr(outer_self, outer_self.pointer_name)(
client=client,
id_at_location=id_at_location,
tags=tags,
description=description,
)
ptr._pointable = pointable
if pointable:
ptr.gc_enabled = False
else:
ptr.gc_enabled = True
# Step 2: create message which contains object to send
storable = StorableObject(
id=ptr.id_at_location,
data=self,
tags=tags,
description=description,
search_permissions={VERIFYALL: None} if pointable else {},
)
obj_msg = SaveObjectAction(obj=storable, address=client.address)
# Step 3: send message
client.send_immediate_msg_without_reply(msg=obj_msg)
# Step 4: return pointer
return ptr
aggressive_set_attr(obj=outer_self.object_ref, name="send", attr=send)
[docs] def create_storable_object_attr_convenience_methods(outer_self: Any) -> None:
"""Add methods to set tag and description to `outer_self.object_ref`."""
def tag(self: Any, *tags: Tuple[Any, ...]) -> object:
"""Add tags to object.
Args:
self: object to add tags to.
*tags: List of tags to add.
Returns:
object.
"""
attach_tags(self, tags) # type: ignore
return self
def describe(self: Any, description: str) -> object:
"""Add description to object.
Args:
self: object to add description to.
description: Description to add.
Returns:
object.
"""
attach_description(self, description)
return self
aggressive_set_attr(obj=outer_self.object_ref, name="tag", attr=tag)
aggressive_set_attr(obj=outer_self.object_ref, name="describe", attr=describe)
[docs] def add_path(
self,
path: Union[str, List[str]],
index: int,
return_type_name: Optional[str] = None,
framework_reference: Optional[ModuleType] = None,
is_static: bool = False,
) -> None:
"""The add_path method adds new nodes in AST based on type of current node and type of object to be added.
Args:
path: The node path added in AST, e.g. `syft.lib.python.List` or ["syft", "lib", "python", "List].
index: The associated position in the path for the current node.
framework_reference: The Python framework in which we can resolve same path to obtain Python object.
return_type_name: The return type name of the given action as a string with its full path.
is_static: If the queried object is static, it has to be found on AST itself, not on an existing pointer.
"""
if index >= len(path) or path[index] in self.attrs:
return
_path: List[str] = path.split(".") if isinstance(path, str) else path
attr_ref = getattr(self.object_ref, _path[index])
class_is_enum = isinstance(self.object_ref, EnumMeta)
if (
inspect.isfunction(attr_ref)
or inspect.isbuiltin(attr_ref)
or inspect.ismethod(attr_ref)
or inspect.ismethoddescriptor(attr_ref)
):
super().add_path(_path, index, return_type_name)
if isinstance(attr_ref, Enum) and class_is_enum:
enum_attribute = ast.enum.EnumAttribute(
path_and_name=".".join(_path[: index + 1]),
return_type_name=return_type_name,
client=self.client,
parent=self,
)
setattr(self, _path[index], enum_attribute)
self.attrs[_path[index]] = enum_attribute
elif inspect.isdatadescriptor(attr_ref) or inspect.isgetsetdescriptor(attr_ref):
self.attrs[_path[index]] = ast.property.Property(
path_and_name=".".join(_path[: index + 1]),
object_ref=attr_ref,
return_type_name=return_type_name,
client=self.client,
parent=self,
)
elif not callable(attr_ref):
static_attribute = ast.static_attr.StaticAttribute(
path_and_name=".".join(_path[: index + 1]),
return_type_name=return_type_name,
client=self.client,
parent=self,
)
setattr(self, _path[index], static_attribute)
self.attrs[_path[index]] = static_attribute
def add_dynamic_object(self, path_and_name: str, return_type_name: str) -> None:
self.attrs[
path_and_name.rsplit(".", maxsplit=1)[-1]
] = ast.dynamic_object.DynamicObject(
path_and_name=path_and_name,
return_type_name=return_type_name,
client=self.client,
parent=self,
)
def __getattribute__(self, item: str) -> Any:
"""Get pointer to attribute.
Args:
item: Attribute.
Returns:
Pointer to the attribute.
"""
# self.apply_node_changes()
try:
target_object = super().__getattribute__(item)
if isinstance(target_object, ast.static_attr.StaticAttribute):
return target_object.get_remote_value()
if isinstance(target_object, ast.enum.EnumAttribute):
target_object_ptr = target_object.get_remote_enum_attribute()
target_object_ptr.is_enum = True
return target_object_ptr
return target_object
except Exception as e:
critical(
"__getattribute__ failed. If you are trying to access an EnumAttribute or a "
"StaticAttribute, be sure they have been added to the AST. Falling back on"
"__getattr__ to search in self.attrs for the requested field."
)
traceback_and_raise(e)
def __getattr__(self, item: str) -> Any:
"""Get value of attribute `item` of the object.
Args:
item: Attribute.
Raises:
KeyError: If attribute `item` is not present.
Returns:
Value of the attribute.
"""
attrs = super().__getattribute__("attrs")
if item not in attrs:
if item == "__name__":
# return the pointer name if __name__ is missing
return self.pointer_name
traceback_and_raise(
KeyError(
f"__getattr__ failed, {item} is not present on the "
f"object, nor the AST attributes!"
)
)
return attrs[item]
def __setattr__(self, key: str, value: Any) -> None:
"""Change value of attribute `key` to `value`.
Args:
key: name of attribute to change.
value: value to change attribute `key` to.
"""
# self.apply_node_changes()
if hasattr(super(), "attrs"):
attrs = super().__getattribute__("attrs")
if key in attrs:
target_object = self.attrs[key]
if isinstance(target_object, ast.static_attr.StaticAttribute):
return target_object.set_remote_value(value)
return super().__setattr__(key, value)
[docs]def pointerize_args_and_kwargs(
args: Union[List[Any], Tuple[Any, ...]], kwargs: Dict[Any, Any], client: Any
) -> Tuple[List[Any], Dict[Any, Any]]:
"""Get pointers to args and kwargs.
Args:
args: List of arguments.
kwargs: Dict of Keyword arguments.
client: Client node.
Returns:
Tuple of args and kwargs with pointer to values.
"""
# When we try to send params to a remote function they need to be pointers so
# that they can be serialized and fetched from the remote store on arrival
# this ensures that any args which are passed in from the user side are first
# converted to pointers and sent then the pointer values are used for the
# method invocation
pointer_args = []
pointer_kwargs = {}
for arg in args:
# check if its already a pointer
if not isinstance(arg, Pointer):
arg_ptr = arg.send(client, pointable=False)
pointer_args.append(arg_ptr)
else:
pointer_args.append(arg)
for k, arg in kwargs.items():
# check if its already a pointer
if not isinstance(arg, Pointer):
arg_ptr = arg.send(client, pointable=False)
pointer_kwargs[k] = arg_ptr
else:
pointer_kwargs[k] = arg
return pointer_args, pointer_kwargs