# Source code for rain.client.session


"""
Global stack of active sessions.

Do not access this stack directly; instead bind a session with

>>> with session:
...     pass

or obtain the innermost one with the function

get_active_session()
"""

from rain.client import rpc
from ..common import RainException, ID
from . import graph

# Stack of currently bound sessions: global_session_push() appends to it,
# global_session_pop() removes the last item, and get_active_session()
# returns the last item (the innermost binding).
_global_sessions = []
# Fallback returned by get_active_session() when the stack above is empty.
# Set by Session.set_as_default(); Session.close() resets it to None when the
# closed session is the current default.
_default_session = None

# TODO: Check attribute "active" before making remote calls


def global_session_push(session):
    """Bind `session` by putting it on top of the global session stack.

    Args:
        session: Object exposing a truthy/falsy ``active`` attribute.

    Raises:
        RainException: If ``session.active`` is falsy; nothing is pushed
            in that case.
    """
    # The stack is only mutated in place, never rebound, so no ``global``
    # declaration is required here.
    if session.active:
        _global_sessions.append(session)
        return
    raise RainException("Session is closed")


def global_session_pop():
    """Unbind and return the session on top of the global session stack.

    Returns:
        The most recently pushed session.

    Raises:
        IndexError: If the stack is empty (propagated from ``list.pop``).
    """
    innermost = _global_sessions.pop()
    return innermost


class SessionBinder:
    """Context manager returned by ``session.bind_only()``.

    Entering pushes the wrapped session onto the global session stack and
    yields it; leaving pops the stack again. Unlike using the session itself
    as a context manager, nothing here closes the session.
    """

    def __init__(self, session):
        # Session that gets bound for the duration of the ``with`` block.
        self.session = session

    def __enter__(self):
        bound = self.session
        global_session_push(bound)
        return bound

    def __exit__(self, type, value, traceback):
        # Sanity check: the entry removed from the stack must be the very
        # session this binder pushed in __enter__.
        popped = global_session_pop()
        assert popped is self.session


class Session:
    """A container for one task graph.

    Do not create directly, rather using :func:`Client.new_session`.
    When used as a context manager, all new objects and tasks are created
    within the session. Note the session is closed afterwards.

    >>> with client.new_session() as s:
    ...     bl = blob("Hello rain!")
    ...     tsk = tasks.sleep(1.0, bl)
    ...     tsk.output.keep()
    ...     s.submit()
    ...     print(tsk.output.fetch())  # waits for completion

    Currently, the graph and objects are alive on the server only as long as
    the `Session` exists.
    """

    def __init__(self, client, session_id, default=False):
        """Set up empty bookkeeping for a session.

        Args:
            client: Client object used for all remote calls of this session.
            session_id: Identifier stored in every ID issued by the session.
            default: If true, register the session as the module-wide
                default via :meth:`set_as_default`.
        """
        self.active = True  # True if a session is live in server
        self.client = client
        self.session_id = session_id

        self._tasks = []  # Unsubmitted tasks
        self._dataobjs = []  # Unsubmitted objects
        # The counter is incremented *before* an id is handed out, so the
        # first id issued by this session is 10.
        self._id_counter = 9
        self._submitted_tasks = []
        self._submitted_dataobjs = []

        # Cache for not submitted constants: bytes/str -> DataObject
        # It is cleared on submit
        # TODO: It is not now implemented
        self._const_cache = {}

        # Static data serves for internal usage of client.
        # It is not directly available to user
        # It is used to store e.g. for serialized Python objects
        self._static_data = {}

        if default:
            self.set_as_default()

    @property
    def task_count(self):
        """The number of unsubmitted tasks."""
        return len(self._tasks)

    @property
    def dataobj_count(self):
        """The number of unsubmitted objects."""
        return len(self._dataobjs)

    def __enter__(self):
        """Bind the session on the global session stack and return it."""
        global_session_push(self)
        return self

    def __exit__(self, type, value, traceback):
        """Unbind the session from the global stack and close it."""
        s = global_session_pop()
        assert s is self
        self.close()

    def __repr__(self):
        return "<Session session_id={}>".format(self.session_id)

    def close(self):
        """Closes session; all tasks are stopped, all objects freed.

        The client is asked to close the session only if the session is
        still active and has a client. Local bookkeeping (unsubmitted and
        submitted tasks and objects) is dropped in every case, and the
        session stops being the default session if it was one. Calling
        `close` repeatedly is harmless: the remote call is skipped once
        `active` is False.
        """
        global _default_session

        if self.active and self.client:
            self.client._close_session(self)
        self._tasks = []
        self._dataobjs = []
        # Previously `_submitted_dataobjs` was reset twice and
        # `_submitted_tasks` never, leaving stale tasks after close.
        self._submitted_tasks = []
        self._submitted_dataobjs = []
        self.active = False

        if _default_session is self:
            _default_session = None

    def bind_only(self):
        """
        This method serves to bind session without autoclose functionality.

        >>> with session.bind_only() as s:
        ...     doSomething()

        binds the session, but do not close it at the end (so it may be bound
        again either with `bind_only` or normally with `with session: ...`).
        """
        return SessionBinder(self)

    def _new_id(self):
        """Advance the id counter and wrap the new value into an `ID`."""
        self._id_counter += 1
        return ID(session_id=self.session_id, id=self._id_counter)

    def _register_task(self, task):
        """Register task into session.

        Returns:
            ID: the assigned id."""
        assert task._session == self and task.id is None
        self._tasks.append(task)
        return self._new_id()

    def _register_dataobj(self, dataobj):
        """Register data object into session.

        Returns:
            ID: the assigned id."""
        assert dataobj._session == self and dataobj.id is None
        self._dataobjs.append(dataobj)
        return self._new_id()

    def keep_all(self):
        """Set keep flag for all unsubmitted objects"""
        for dataobj in self._dataobjs:
            dataobj.keep()

    def submit(self):
        """Submit all unsubmitted tasks and objects.

        After the client call returns, every submitted task is marked
        `notAssigned`, every submitted object `unfinished`, both are moved
        to the "submitted" lists and the "unsubmitted" lists are emptied.
        """
        self.client._submit(self._tasks, self._dataobjs)
        for task in self._tasks:
            task._state = rpc.common.TaskState.notAssigned
            self._submitted_tasks.append(task)
        for dataobj in self._dataobjs:
            dataobj._state = rpc.common.DataObjectState.unfinished
            self._submitted_dataobjs.append(dataobj)
        self._tasks = []
        self._dataobjs = []

    def _split_tasks_objects(self, items):
        """Split `items` into `Task`s and `DataObject`s, raising error on
        anything else.

        Returns:
            `(tasks, dataobjs)`

        Raises:
            TypeError: if an item is neither a `Task` nor a `DataObject`.
        """
        # Imported at call time as in the original code (presumably to
        # avoid a circular import with the package - TODO confirm).
        from . import Task, DataObject

        tasks, dataobjs = [], []
        for i in items:
            if isinstance(i, Task):
                tasks.append(i)
            elif isinstance(i, DataObject):
                dataobjs.append(i)
            else:
                raise TypeError("Neither Task or DataObject: {!r}".format(i))
        return (tasks, dataobjs)

    @staticmethod
    def _mark_finished(tasks, dataobjs):
        """Set the local state of the given tasks and objects to finished."""
        for task in tasks:
            task._state = rpc.common.TaskState.finished
        for dataobj in dataobjs:
            dataobj._state = rpc.common.DataObjectState.finished

    def wait(self, items):
        """Wait until *all* specified tasks and dataobjects are finished."""
        tasks, dataobjs = self._split_tasks_objects(items)
        self.client._wait(tasks, dataobjs)
        self._mark_finished(tasks, dataobjs)

    def wait_some(self, items):
        """Wait until *some* of specified tasks/dataobjects are finished.

        Returns:
            `(finished_tasks, finished_dataobjs)`"""
        tasks, dataobjs = self._split_tasks_objects(items)
        finished_tasks, finished_dataobjs = self.client._wait_some(
            tasks, dataobjs)
        self._mark_finished(finished_tasks, finished_dataobjs)
        return finished_tasks, finished_dataobjs

    def wait_all(self):
        """Wait until all submitted tasks and objects are finished."""
        self.client._wait_all(self)
        self._mark_finished(self._submitted_tasks, self._submitted_dataobjs)

    def fetch(self, dataobject):
        """Fetch the data of `dataobject` through the client.

        The work is delegated to `client._fetch`; according to the original
        documentation this waits for the object to finish and updates its
        state.

        Returns:
            `DataInstance`: The object data proxy."""
        return self.client._fetch(dataobject)

    def unkeep(self, dataobjects):
        """Unset keep flag for given objects.

        Objects whose `state` is None have only their local `_keep` flag
        cleared; all others are unkept through the client and then freed
        locally via `_free()`.

        Raises:
            TypeError: if an item is not a `DataObject`.
            RainException: if an object is not kept.
        """
        from . import DataObject

        submitted = []
        for dataobj in dataobjects:
            if not isinstance(dataobj, DataObject):
                raise TypeError("Not a DataObject: {!r}".format(dataobj))
            if not dataobj.is_kept():
                raise RainException("Object {} is not kept".format(dataobj.id))
            if dataobj.state is not None:
                submitted.append(dataobj)
            else:
                dataobj._keep = False

        # Nothing needs a remote call.
        if not submitted:
            return

        self.client._unkeep(submitted)
        for dataobj in submitted:
            dataobj._free()

    def update(self, items):
        """Update the status and metadata of given tasks and objects."""
        self.client.update(items)

    def set_as_default(self):
        """Make this session the fallback used when no session is bound."""
        global _default_session
        _default_session = self

    def make_graph(self, show_ids=True):
        """Create a graph of the session's tasks and objects.

        Both unsubmitted and already submitted tasks and objects are
        included.

        Args:
            show_ids: Currently unused; kept for interface compatibility.

        Returns:
            The populated `graph.Graph`.
        """

        def arc_label(index, key):
            # Unnamed inputs/outputs are labelled by position only.
            if key is None:
                return str(index)
            return "{}: {}".format(index, key)

        def add_obj(o):
            if o is None:
                return
            node = g.node(o)
            node.label = o.id
            node.shape = "box"
            # Kept and unkept objects end up with the same fill colour (the
            # original overwrote its "#0088aa" store right away); kept ones
            # are additionally outlined in black.
            node.fillcolor = "#44ccff"
            node.color = "black" if o.is_kept() else "none"

        def add_task(t):
            if t is None:
                return
            node = g.node(t)
            node.label = "{}\n{}".format(t.id_pair, t.task_type)
            node.shape = "oval"
            node.fillcolor = "#0088aa"
            node.color = "none"
            # Arcs go input object -> task and task -> output object.
            for i, (key, o) in enumerate(t.inputs.items()):
                g.node(o).add_arc(node, arc_label(i, key))
            for i, (key, o) in enumerate(t.outputs.items()):
                node.add_arc(g.node(o), arc_label(i, key))

        g = graph.Graph()
        for o in self._dataobjs + self._submitted_dataobjs:
            add_obj(o)
        for t in self._tasks + self._submitted_tasks:
            add_task(t)
        return g
def get_active_session():
    """Internal helper to get innermost active `Session`.

    The most recently bound session on the global stack wins; when nothing
    is bound, the default session is used instead.

    Raises:
        RainException: If no session is bound and no default session is set.
    """
    if _global_sessions:
        return _global_sessions[-1]
    if _default_session:
        return _default_session
    raise RainException("No active session")