"""
This is the main Plan class which is responsible for containing a list of Actions
which can be serialized, deserialized, and executed by substituting the run time
pointers with the original traced pointers and replaying the actions against a node.
"""
# stdlib
import re
import sys
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
# third party
from google.protobuf.reflection import GeneratedProtocolMessageType
from nacl.signing import VerifyKey
# syft relative
from ... import serialize
from ...logger import traceback_and_raise
from ...proto.core.node.common.action.action_pb2 import Action as Action_PB
from ...proto.core.plan.plan_pb2 import Plan as Plan_PB
from ..common.object import Serializable
from ..common.serde.serializable import bind_protobuf
from ..node.abstract.node import AbstractNode
from ..node.common import client
from ..node.common.action.common import Action
from ..node.common.util import listify
from ..pointer.pointer import Pointer
from ..store.storeable_object import StorableObject
# Matches the zero-width position in front of every uppercase ASCII letter,
# except at the very start of the string. Substituting "_" at those positions
# and lower-casing the result turns "CamelCase" into "camel_case".
CAMEL_TO_SNAKE_PAT = re.compile(r"(?<!^)(?=[A-Z])")
@bind_protobuf
class Plan(Serializable):
    """
    A plan is a collection of actions, plus some variable inputs, that together
    form a computation graph.

    Attributes:
        actions: list of actions that make up the plan
        inputs: mapping from input name to the Pointer that was traced for it
        outputs: Pointers to the outputs of the plan
        i2o_map: mapping from an input name to an index into ``outputs``; after
            the actions have run, that output slot is overwritten with the
            pointer currently bound to the named input
        code: optional source code of the plan (within this class it is only
            read by ``__repr__``)
        max_calls: optional number of allowed executions (within this class it
            is only read by ``__repr__``; it is not enforced here)
        n_calls: number of times the plan has been called
    """

    def __init__(
        self,
        actions: Union[List[Action], None] = None,
        inputs: Union[Dict[str, Pointer], None] = None,
        outputs: Union[Pointer, List[Pointer], None] = None,
        i2o_map: Union[Dict[str, int], None] = None,
        code: Optional[str] = None,
        max_calls: Optional[int] = None,
    ):
        """
        Initialize the Plan with actions, inputs and outputs

        Args:
            actions: the actions of the plan. Defaults to None.
            inputs: input name -> traced Pointer. Defaults to None (no inputs).
            outputs: a single Pointer or a list of Pointers. Defaults to None.
            i2o_map: input name -> index into ``outputs``. Defaults to None.
            code: optional source code of the plan.
            max_calls: optional number of allowed executions.
        """
        # NOTE(review): `listify` is assumed to turn None / a single value into
        # a list -- confirm against ..node.common.util.
        self.actions: List[Action] = listify(actions)
        self.inputs: Dict[str, Pointer] = inputs if inputs is not None else dict()
        self.outputs: List[Pointer] = listify(outputs)
        self.i2o_map: Dict[str, int] = i2o_map if i2o_map is not None else dict()
        self.code = code
        self.max_calls = max_calls
        self.n_calls = 0

    def __call__(
        self,
        node: Optional[AbstractNode] = None,
        verify_key: Optional[VerifyKey] = None,
        **kwargs: Any,
    ) -> List[StorableObject]:
        """
        1) For all pointers that were passed into the init as `inputs`, this method
        replaces those pointers in self.actions by the pointers passed in as **kwargs
        (matched by input name).
        2) Executes the actions in self.actions one by one

        *While this function requires `node` and `verify_key` as inputs, during remote
        execution, passing these is handled in `RunClassMethodAction`*

        *Note that this method will receive the keyword arguments as pointers during
        execution. Normally, pointers are resolved during
        `RunClassMethodAction.execute()`, but not for plans, as they need to operate
        on the pointer to enable remapping of the inputs.*

        Args:
            node: the node to execute the actions on. When None, the call is
                delegated to `execute_locally`.
            verify_key: forwarded to every action's `execute_action`.
            **kwargs: the new inputs for the plan, passed as pointers and keyed
                by the names used in `self.inputs`.

        Returns:
            The `.data` of the objects stored on `node` for every output pointer
            (an empty list when the plan has no outputs).

        Raises:
            KeyError: if an input of the plan is missing from `kwargs`.
            Exception: via `traceback_and_raise` if a new input is not a Pointer.
        """
        self.n_calls += 1

        if node is None:
            return self.execute_locally(**kwargs)

        # Validate every replacement input *before* touching any action, so a
        # missing or non-Pointer argument cannot leave the actions half-remapped
        # while self.inputs still refers to the old pointers.
        new_inputs: Dict[str, Pointer] = {}
        for k in self.inputs:
            new_input = kwargs[k]
            if not isinstance(new_input, Pointer):
                traceback_and_raise(
                    f"Calling Plan without a Pointer. {k} == {type(new_input)} "
                )
            new_inputs[k] = new_input  # type: ignore

        # this is pretty cumbersome, we are searching through all actions to check
        # if we need to redefine some of their attributes that are inputs in the
        # graph of actions
        for k, current_input in self.inputs.items():
            new_input = new_inputs[k]
            for a in self.actions:
                if hasattr(a, "remap_input"):
                    a.remap_input(current_input, new_input)  # type: ignore

        # redefine the inputs of the plan
        self.inputs = new_inputs

        for a in self.actions:
            a.execute_action(node, verify_key)

        # outputs that are mapped to an input are replaced by the pointer that is
        # now bound to that input
        for k, v in self.i2o_map.items():
            self.outputs[v] = self.inputs[k]

        # resolve the output pointers against the node's store
        return [node.store[arg.id_at_location].data for arg in self.outputs]

    def __repr__(self) -> str:
        """Return a multi-line, human readable summary of the plan."""
        obj_str = "Plan"

        allowed, remaining = (
            (self.max_calls, self.max_calls - self.n_calls)
            if self.max_calls is not None
            else ("not defined", "not defined")
        )
        ex_str = f"Allowed executions:\t{allowed}\nRemaining executions:\t{remaining}"

        inp_str = "Inputs:\n"
        inp_str += "\n".join(
            f"\t\t{k}:\t{v.__class__.__name__}" for k, v in self.inputs.items()
        )

        act_str = f"Actions:\n\t\t{len(self.actions)} Actions"

        out_str = "Outputs:\n"
        out_str += "\n".join(f"\t\t{o.__class__.__name__}" for o in self.outputs)

        plan_str = "Plan code:\n"
        plan_str += f'"""\n{self.code}\n"""' if self.code is not None else ""

        return f"{obj_str}\n{ex_str}\n{inp_str}\n{act_str}\n{out_str}\n\n{plan_str}"

    def execute_locally(self, **kwargs: Any) -> List[StorableObject]:
        """Execute a plan by sending it to a virtual machine and calling execute on the pointer.

        This is a workaround until we have a way to execute plans locally.

        Args:
            **kwargs: forwarded unchanged to the call on the plan pointer.

        Returns:
            The result of calling `.get()` on what the plan pointer returned.
        """
        # prevent circular dependency
        # syft relative
        from ...core.node.vm.vm import VirtualMachine

        alice = VirtualMachine(name="plan_executor")
        alice_client: client.Client = alice.get_client()
        # NOTE(review): `send` is not defined in this class; presumably it is
        # attached elsewhere (hence the type: ignore) -- confirm.
        self_ptr = self.send(alice_client)  # type: ignore
        out = self_ptr(**kwargs)
        return out.get()

    @staticmethod
    def get_protobuf_schema() -> GeneratedProtocolMessageType:
        """Return the type of protobuf object which stores a class of this type

        As a part of serialization and deserialization, we need the ability to
        lookup the protobuf object type directly from the object type. This
        static method allows us to do this.

        Importantly, this method is also used to create the reverse lookup ability within
        the metaclass of Serializable. In the metaclass, it calls this method and then
        it takes whatever type is returned from this method and adds an attribute to it
        with the type of this class attached to it. See the MetaSerializable class for details.

        :return: the type of protobuf object which corresponds to this class.
        :rtype: GeneratedProtocolMessageType
        """
        return Plan_PB

    def _object2proto(self) -> Plan_PB:
        """Returns a protobuf serialization of self.

        As a requirement of all objects which inherit from Serializable,
        this method transforms the current object into the corresponding
        Protobuf object so that it can be further serialized.

        Only `actions`, `inputs`, `outputs` and `i2o_map` are written to the
        message; `code`, `max_calls` and `n_calls` are not serialized here.

        :return: returns a protobuf object
        :rtype: Plan_PB

        .. note::
            This method is purely an internal method. Please use object.serialize() or one of
            the other public serialization methods if you wish to serialize an
            object.
        """

        def camel_to_snake(s: str) -> str:
            """Convert CamelCase classes to snake case for matching protobuf names"""
            return CAMEL_TO_SNAKE_PAT.sub("_", s).lower()

        # Each action is wrapped in the generic Action message: `obj_type` keeps
        # the fully qualified class name, and the serialized action is stored in
        # the field named after the snake-cased class name.
        actions_pb = [
            Action_PB(
                obj_type=".".join([action.__module__, action.__class__.__name__]),
                **{camel_to_snake(action.__class__.__name__): serialize(action)},
            )
            for action in self.actions
        ]
        inputs_pb = {k: v._object2proto() for k, v in self.inputs.items()}
        outputs_pb = [out._object2proto() for out in self.outputs]

        return Plan_PB(
            actions=actions_pb,
            inputs=inputs_pb,
            outputs=outputs_pb,
            i2o_map=self.i2o_map,
        )

    @staticmethod
    def _proto2object(proto: Plan_PB) -> "Plan":
        """Creates a Plan from a protobuf

        As a requirement of all objects which inherit from Serializable,
        this method transforms a protobuf object into an instance of this class.

        The returned plan is built from `actions`, `inputs`, `outputs` and
        `i2o_map` only, so `code`, `max_calls` and `n_calls` have their
        constructor defaults.

        :return: returns an instance of Plan
        :rtype: Plan

        .. note::
            This method is purely an internal method. Please use syft.deserialize()
            if you wish to deserialize an object.
        """
        actions = []
        for action_proto in proto.actions:
            # NOTE(review): `obj_type` comes from the serialized message and is
            # used to pick an attribute of an already-imported module; if the
            # message can come from an untrusted peer, consider restricting this
            # to a known set of Action classes.
            module, cls_name = action_proto.obj_type.rsplit(".", 1)
            action_cls = getattr(sys.modules[module], cls_name)

            # protobuf does no inheritance, so we wrap action subclasses
            # in the main action class.
            inner_action = getattr(action_proto, action_proto.WhichOneof("action"))
            actions.append(action_cls._proto2object(inner_action))

        inputs = {k: Pointer._proto2object(proto.inputs[k]) for k in proto.inputs}
        outputs = [
            Pointer._proto2object(pointer_proto) for pointer_proto in proto.outputs
        ]
        # copy into a plain dict so the plan does not keep a reference to the
        # protobuf message's map container
        i2o_map = dict(proto.i2o_map)

        return Plan(actions=actions, inputs=inputs, outputs=outputs, i2o_map=i2o_map)