Source code for rif.dtypes

"""overrides numpy operators to work with rif dtypes

override numpy operators using numpy.set_numeric_ops. keep track of
overridden operators and call them if a custom operator is not found

"""
import sys
import numpy as np
import rif
import rif.actor
import rif.eigen_types
import pandas
from rif import V3, M3, X3, Atom
from functools import wraps

# (owner, attribute name) pairs that wrap_broken_functions() replaces with
# versions that run with the rif operators disabled.
# np.ndarray.cumsum is deliberately absent: attributes of the built-in
# ndarray type cannot be reassigned.
_FUNCS_BROKEN_WITH_RIOPS = [
    (np, 'cumsum'),
    (pandas.core.internals.BlockManager, 'delete'),
]

# Modules searched by _init_dispatch() for ``rifop_*`` implementations.
_RIFOP_MODULES = [rif.eigen_types, rif.actor]
# Dispatch tables populated by _init_dispatch():
#   unary:  (operand type or dtype, numpy ufunc name) -> rif operator
#   binary: (left type or dtype, right type or dtype, numpy ufunc name)
#           -> rif operator
_NPY_RIF_OP1MAP = {}
_NPY_RIF_OP2MAP = {}

# Value returned by np.set_numeric_ops while the rif overrides are installed;
# None whenever they are not.
_ORIG_NUMPY_OPS = None

# Operator tag used in ``rifop_<op>_...`` names -> numpy ufunc name (unary).
_opmap1 = {'abs': 'absolute'}

# Operator tag used in ``rifop_<op>_...`` names -> numpy ufunc name (binary).
_opmap2 = {
    'add': 'add',
    'mul': 'multiply',
    'sub': 'subtract',
    'div': 'divide',
}


def _init_dispatch():
    """Populate the rif operator dispatch tables.

    Every module in ``_RIFOP_MODULES`` is scanned for attributes whose name
    starts with ``rifop_``.  A name with three ``_``-separated parts,
    ``rifop_<op>_<t1>``, is registered as a unary operator in
    ``_NPY_RIF_OP1MAP``; a name with four parts, ``rifop_<op>_<t1>_<t2>``,
    is registered as a binary operator in ``_NPY_RIF_OP2MAP``.  Names with
    any other number of parts are ignored.

    ``<op>`` is translated to a numpy ufunc name through ``_opmap1`` /
    ``_opmap2`` and each type tag is expanded to one or more concrete
    types/dtypes, so one implementation may be registered under several keys.

    Raises:
        KeyError: if a matching name uses an operator or type tag that is
            missing from the lookup tables.
    """
    # The module-level tables are only mutated in place, never rebound, so no
    # ``global`` declaration is required.
    dtmap = dict(
        fl=(type(1), type(1.0)),
        V3=V3.dtype,
        M3=M3.dtype,
        X3=X3.dtype,
        AT=Atom.dtype,
    )
    # Normalise every entry to an iterable: x -> (x,) iff x is not iterable.
    dtmap = {k: v if hasattr(v, '__iter__') else (v,)
             for k, v in dtmap.items()}
    for module in _RIFOP_MODULES:
        for fn in dir(module):
            if not fn.startswith('rifop_'):
                continue
            splt = fn.split('_')
            # Compare lengths by value; ``is`` on ints tests object identity.
            if len(splt) == 3:
                _, op, t1 = splt
                for dt1 in dtmap[t1]:
                    k = dt1, _opmap1[op]
                    _NPY_RIF_OP1MAP[k] = getattr(module, fn)
            elif len(splt) == 4:
                _, op, t1, t2 = splt
                for dt1 in dtmap[t1]:
                    for dt2 in dtmap[t2]:
                        k = dt1, dt2, _opmap2[op]
                        _NPY_RIF_OP2MAP[k] = getattr(module, fn)


# Fill the dispatch tables once, when this module is imported.
_init_dispatch()


def _get_type_str(t):
    """Return ``str(t.dtype)``, or ``''`` when ``t`` has no ``dtype``."""
    try:
        dtype_name = str(t.dtype)
    except AttributeError:
        # no dtype attribute: treated as a plain scalar
        dtype_name = ''
    return dtype_name


def _override1(name):
    """Build the replacement for the unary numpy operator ``name``.

    The returned callable looks up ``(type_of_first_arg, name)`` in
    ``_NPY_RIF_OP1MAP``, where the type is the argument's ``dtype`` when it
    has one and ``type(arg)`` otherwise.  If a rif operator is registered it
    is called; if not, the saved numpy operator ``_ORIG_NUMPY_OPS[name]`` is
    called instead.  All positional and keyword arguments are forwarded
    unchanged.
    """
    def ufunc(*args, **kwargs):
        t = args[0].dtype if hasattr(args[0], 'dtype') else type(args[0])
        # Only the table lookup is guarded: an exception raised inside the
        # selected operator must propagate, not be mistaken for "no rif
        # operator registered" and silently retried through numpy.
        try:
            op = _NPY_RIF_OP1MAP[t, name]
        except KeyError:
            op = _ORIG_NUMPY_OPS[name]
        return op(*args, **kwargs)
    return ufunc


def _override2(name):
    """Build the replacement for the binary numpy operator ``name``.

    The returned callable looks up ``(type_of_arg0, type_of_arg1, name)`` in
    ``_NPY_RIF_OP2MAP``, where each type is the argument's ``dtype`` when it
    has one and ``type(arg)`` otherwise.  If a rif operator is registered it
    is called; if not, the saved numpy operator ``_ORIG_NUMPY_OPS[name]`` is
    called instead.  All positional and keyword arguments are forwarded
    unchanged.
    """
    def ufunc(*args, **kwargs):
        t1 = args[0].dtype if hasattr(args[0], 'dtype') else type(args[0])
        t2 = args[1].dtype if hasattr(args[1], 'dtype') else type(args[1])
        # Only the table lookup is guarded: a KeyError raised inside the
        # selected operator must propagate, not be mistaken for "no rif
        # operator registered" and silently retried through numpy.
        try:
            op = _NPY_RIF_OP2MAP[t1, t2, name]
        except KeyError:
            op = _ORIG_NUMPY_OPS[name]
        return op(*args, **kwargs)
    return ufunc


def rif_operators_are_enabled():
    """Return True while the rif operator overrides are installed."""
    # _ORIG_NUMPY_OPS is None exactly when the overrides are not installed.
    enabled = _ORIG_NUMPY_OPS is not None
    return enabled


def global_rif_operators_enable(quiet=False):
    """Enable rif operators via numpy.set_numeric_ops.

    Installs an override for every ufunc named in ``_opmap1`` (unary) and
    ``_opmap2`` (binary), stores the value returned by
    ``np.set_numeric_ops`` in ``_ORIG_NUMPY_OPS`` so it can be restored
    later, and wraps the functions listed in ``_FUNCS_BROKEN_WITH_RIOPS``.

    Args:
        quiet: when true, suppress the warning printed if the rif operators
            are already enabled.

    NOTE(review): numpy.set_numeric_ops is deprecated in recent NumPy
    releases -- confirm the NumPy versions this module must support.
    """
    global _ORIG_NUMPY_OPS
    if rif_operators_are_enabled():
        # Already installed: nothing to do.  Re-wrapping here would trip the
        # assertion in wrap_broken_functions().
        if not quiet:
            print('warning: global_rif_ops is already enabled')
        return
    overrides = {ufunc: _override1(ufunc) for ufunc in _opmap1.values()}
    overrides.update(
        {ufunc: _override2(ufunc) for ufunc in _opmap2.values()})
    _ORIG_NUMPY_OPS = np.set_numeric_ops(**overrides)
    wrap_broken_functions()
def global_rif_operators_disable(quiet=False):
    """Disable rif operators via numpy.set_numeric_ops.

    Restores the numpy operators saved in ``_ORIG_NUMPY_OPS``, resets it to
    None and unwraps the functions listed in ``_FUNCS_BROKEN_WITH_RIOPS``.

    Args:
        quiet: accepted for symmetry with global_rif_operators_enable;
            this function prints nothing, so it is currently unused.

    Raises:
        AssertionError: if the rif operators are not currently enabled.
    """
    global _ORIG_NUMPY_OPS
    # Explicit check instead of ``assert`` so the precondition still holds
    # under ``python -O``; the exception type is unchanged.
    if not _ORIG_NUMPY_OPS:
        raise AssertionError('rif operators are not enabled')
    np.set_numeric_ops(**_ORIG_NUMPY_OPS)
    _ORIG_NUMPY_OPS = None
    unwrap_broken_functions()
class RifOperators(object):
    """context manager for locally enabling rif ops

    On entry the rif operators are enabled unless they already are; on exit
    they are disabled again only if this context manager enabled them.
    """

    def __enter__(self):
        # Remember whether this context is responsible for switching back.
        self.previously_not_using_rifops = _ORIG_NUMPY_OPS is None
        if self.previously_not_using_rifops:
            global_rif_operators_enable()

    def __exit__(self, *args):
        # ``args`` is (exc_type, exc_value, traceback) and therefore always
        # non-empty; report only when an exception is actually propagating.
        if args and args[0] is not None:
            print("========== RifOperators: exit ============")
            for a in args:
                print(a)
            print('----------- end rifops exit --------------')
        if self.previously_not_using_rifops:
            global_rif_operators_disable()
class RifOperatorsDisabled(object):
    """context manager to locally disable RifOperators

    On entry the rif operators are disabled if they were enabled; on exit
    they are re-enabled only if this context manager disabled them.
    """

    def __enter__(self):
        was_enabled = rif_operators_are_enabled()
        self.previously_using_rifops = was_enabled
        if was_enabled:
            global_rif_operators_disable()

    def __exit__(self, *exc_info):
        if not self.previously_using_rifops:
            return
        global_rif_operators_enable()
def with_rifops_enabled(f):
    """Decorator: run every call of ``f`` inside a RifOperators context."""
    @wraps(f)
    def run_enabled(*args, **kwargs):
        with RifOperators():
            result = f(*args, **kwargs)
        return result
    return run_enabled


def with_rifops_disabled(f):
    """Decorator: run every call of ``f`` inside a RifOperatorsDisabled context."""
    @wraps(f)
    def run_disabled(*args, **kwargs):
        with RifOperatorsDisabled():
            result = f(*args, **kwargs)
        return result
    return run_disabled


# Originals of the wrapped functions, keyed by (owner.__name__, attribute
# name), while they are wrapped; None otherwise.
_ORIG_WRAPPED_BROKEN_FUNCTIONS = None


def wrap_broken_functions():
    """Replace each entry of _FUNCS_BROKEN_WITH_RIOPS with a rifops-disabled wrapper.

    The originals are stored in _ORIG_WRAPPED_BROKEN_FUNCTIONS so that
    unwrap_broken_functions() can restore them.
    """
    global _ORIG_WRAPPED_BROKEN_FUNCTIONS
    assert not _ORIG_WRAPPED_BROKEN_FUNCTIONS
    saved = {}
    _ORIG_WRAPPED_BROKEN_FUNCTIONS = saved
    for owner, attr in _FUNCS_BROKEN_WITH_RIOPS:
        original = getattr(owner, attr)
        saved[owner.__name__, attr] = original
        setattr(owner, attr, with_rifops_disabled(original))


def unwrap_broken_functions():
    """Restore the originals saved by wrap_broken_functions()."""
    global _ORIG_WRAPPED_BROKEN_FUNCTIONS
    assert _ORIG_WRAPPED_BROKEN_FUNCTIONS is not None
    for owner, attr in _FUNCS_BROKEN_WITH_RIOPS:
        original = _ORIG_WRAPPED_BROKEN_FUNCTIONS[owner.__name__, attr]
        setattr(owner, attr, original)
    _ORIG_WRAPPED_BROKEN_FUNCTIONS = None

# global_rif_operators_enable()
# assert rif_operators_are_enabled()