Source code for syft.ast.attribute

"""This module contains Attribute, an interface of a generic node in the AST."""

# stdlib
from types import ModuleType
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union

# syft relative
from .. import ast
from ..core.node.abstract.node import AbstractNodeClient
from ..logger import traceback_and_raise


class Attribute:
    """Attribute is the interface of a generic node in the AST that covers basic functionality."""

    __slots__ = [
        "path_and_name",
        "object_ref",
        "attrs",
        "return_type_name",
        "client",
        "_parent",
    ]

    # Class-level dict (not listed in `__slots__`, so it is shared by all
    # instances unless a subclass overrides it). `query` uses it to map an
    # object type to that type's path in the AST. Nothing in this class writes
    # to it — presumably it is populated by subclasses or callers (TODO confirm).
    lookup_cache: Dict[Any, Any] = {}

    def __init__(
        self,
        client: Optional[AbstractNodeClient],
        path_and_name: Optional[str] = None,
        object_ref: Any = None,
        return_type_name: Optional[str] = None,
        parent: Optional["Attribute"] = None,
    ) -> None:
        """Base constructor for all AST nodes.

        Args:
            client: The client for which all computation is being executed.
            path_and_name: The path for the current node, e.g. `syft.lib.python.List`.
            object_ref: The actual python object for which the computation is being made.
            return_type_name: The given action's return type name, with its full path, in string format.
            parent: The parent node in the AST.
        """
        self.client: Optional[AbstractNodeClient] = client
        self.path_and_name: Optional[str] = path_and_name
        self.object_ref: Any = object_ref
        self.return_type_name: Optional[str] = return_type_name
        # The `attrs` are the nodes that have the current node as a parent node;
        # maps from the name on the path to the actual attribute.
        self.attrs: Dict[str, "Attribute"] = {}
        self._parent: Optional["Attribute"] = parent

    def __call__(
        self,
        path: Union[List[str], str],
        index: int = 0,
        obj_type: Optional[type] = None,
    ) -> Any:
        """Execute the given node object reference with the given parameters.

        This base implementation delegates to
        `traceback_and_raise(NotImplementedError)`; concrete node types are
        expected to override it.

        Args:
            path: The node path in AST to execute, e.g. `syft.lib.python.List` or
                ["syft", "lib", "python", "List"].
            index: The associated position in the path for the current node.
            obj_type: The type of the object to be called, whose path is resolved from the `lookup_cache`.

        Returns:
            The results of running the computation on the object ref.
        """
        traceback_and_raise(NotImplementedError)

    def _extract_attr_type(
        self,
        container: Union[
            List["ast.klass.Class"],
            List["ast.module.Module"],
            List["ast.property.Property"],
        ],
        field: str,
    ) -> None:
        """Helper function to extract a class of nodes whose parent is the current node.

        For every child in `attrs`, the child's `field` attribute is read (if it
        exists and is not None) and its items are appended to `container`.

        Args:
            container: A list in which the results are stored (mutated in place).
            field: The type of attribute to collect from the current node's `attrs`,
                e.g. "classes" or "properties".
        """
        for ref in self.attrs.values():
            sub_prop = getattr(ref, field, None)
            if sub_prop is None:
                continue
            container.extend(sub_prop)

    @property
    def classes(self) -> List["ast.klass.Class"]:
        """Extract all classes from the current node attributes.

        Returns:
            The list of classes in the current AST node attributes. The current
            node itself is included first if it is an `ast.klass.Class`.
        """
        out: List["ast.klass.Class"] = []
        if isinstance(self, ast.klass.Class):
            out.append(self)
        # Each child's own `classes` property recurses further down the tree.
        self._extract_attr_type(out, "classes")
        return out

    @property
    def properties(self) -> List["ast.property.Property"]:
        """Extract all properties from the current node attributes.

        Returns:
            The list of properties in the current AST node attributes. The
            current node itself is included first if it is an
            `ast.property.Property`.
        """
        out: List["ast.property.Property"] = []
        if isinstance(self, ast.property.Property):
            out.append(self)
        # Each child's own `properties` property recurses further down the tree.
        self._extract_attr_type(out, "properties")
        return out

    def query(
        self, path: Union[List[str], str], obj_type: Optional[type] = None
    ) -> "Attribute":
        """The query method is a tree traversal function based on the path to retrieve the node.

        It has a similar functionality to `__call__`, main difference being that `query` retrieves
        node without performing execution on node.

        Args:
            path: The node path in AST to be queried, e.g. `syft.lib.python.List` or
                ["syft", "lib", "python", "List"].
            obj_type: The type of the object that we want to call, whose path is resolved from
                the `lookup_cache`.

        Returns:
            The attribute in the AST at the given initial path.

        Raises:
            ValueError: (via `traceback_and_raise`) if the path is not present in the AST.
        """
        # TODO: fix hacky work around
        if path == "syft.lib.python.list.List":
            path = "syft.lib.python.List"

        # If the searched type has already been seen, resolve it with the path from `lookup_cache`.
        if obj_type is not None and obj_type in self.lookup_cache:
            path = self.lookup_cache[obj_type]

        _path = path if isinstance(path, list) else path.split(".")

        # An exhausted path means the current node is the one being looked up.
        if not _path:
            return self

        # If the first element of the path is a child node, continue the query in the child node.
        if _path[0] in self.attrs:
            return self.attrs[_path[0]].query(path=_path[1:])

        traceback_and_raise(
            ValueError(f"Path {'.'.join(_path)} not present in the AST.")
        )

    @property
    def name(self) -> str:
        """Retrieve the name of the current AST node from its `path_and_name`.

        Returns:
            The last dot-separated component of `path_and_name`, or "" if it is unset.
        """
        path_and_name = self.path_and_name if self.path_and_name else ""
        return path_and_name.rsplit(".", maxsplit=1)[-1]

    def add_path(
        self,
        path: List[str],
        index: int,
        return_type_name: Optional[str] = None,
        framework_reference: Optional[ModuleType] = None,
        is_static: bool = False,
    ) -> None:
        """The add_path method adds new nodes in AST based on type of current node and type of object to be added.

        This base implementation delegates to
        `traceback_and_raise(NotImplementedError)`; concrete node types are
        expected to override it.

        Args:
            path: The node path added in AST, e.g. `syft.lib.python.List` or
                ["syft", "lib", "python", "List"].
            index: The associated position in the path for the current node.
            return_type_name: The return type name of the given action as a string with its full path.
            framework_reference: The Python framework in which we can resolve same path to obtain Python object.
            is_static: If the queried object is static, it has to be found on AST itself, not on an existing pointer.
        """
        traceback_and_raise(NotImplementedError)

    def fetch_live_object(self) -> Any:
        """Look up this node's name on the parent's `object_ref` and return it.

        Raises:
            AttributeError: If the node has no parent, or the parent's
                `object_ref` has no attribute named `self.name`.
        """
        return getattr(self.parent.object_ref, self.name)

    def object_change(self) -> bool:
        """Check whether the live object differs from the stored `object_ref`.

        Returns:
            True if the object currently found on the parent's `object_ref` is
            not the same object (by identity) as this node's `object_ref`.
        """
        return self.fetch_live_object() is not self.object_ref

    def reconstruct_node(self) -> None:
        """Replace this node's `object_ref` with the live object fetched from the parent."""
        self.object_ref = self.fetch_live_object()

    def apply_node_changes(self) -> None:
        """Refresh `object_ref` from the parent if the live object has changed.

        Nodes without a parent are left untouched.
        """
        if self._parent is not None and self.object_change():
            self.reconstruct_node()

    @property
    def parent(self) -> "Attribute":
        """Return the parent node of the current node.

        Returns:
            Attribute: parent node

        Raises:
            AttributeError: If the node has no parent attribute set.
        """
        if self._parent is not None:
            return self._parent
        raise AttributeError(f"Node {self} in the AST has no parent attribute set!")