# Source code for sfepy.terms.terms_multilinear

import numpy as nm

try:
    import dask.array as da

except ImportError:
    da = None

try:
    import opt_einsum as oe

except ImportError:
    oe = None

try:
    from jax.config import config
    config.update("jax_enable_x64", True)
    import jax
    import jax.numpy as jnp

except ImportError:
    jnp = jax = None

from pyparsing import (Word, Suppress, oneOf, OneOrMore, delimitedList,
                       Combine, alphas, alphanums, Literal)

from sfepy.base.base import output, Struct
from sfepy.base.timing import Timer
from sfepy.mechanics.tensors import dim2sym
from sfepy.terms.terms import Term

def _get_char_map(c1, c2):
    """
    Map each character of `c1` to the character at the same position in `c2`.

    When a character repeats in `c1` and its counterpart in `c2` differs from
    the string mapped so far, the counterpart is appended to that string.

    Parameters
    ----------
    c1, c2 : str
        The source and target strings; `c2` must be at least as long as `c1`.

    Returns
    -------
    mm : dict
        The character map.
    """
    mm = {}
    for ic, char in enumerate(c1):
        target = c2[ic]
        if char in mm:
            # Previously a debug print() was emitted here on every repeated
            # character; the map is now built silently.
            if mm[char] != target:
                mm[char] += target

        else:
            mm[char] = target

    return mm



def collect_modifiers(modifiers):
    """
    Return a pyparsing parse action that records argument modifiers.

    For every parsed argument the action appends one entry to `modifiers`:

    - None, when the argument consists of at most one token (no modifier);
    - otherwise a list that receives a copy of all the tokens once per group
      of three tokens ``modifier, subscripts, letter``.

    A modified argument is returned as its subscripts, with every occurrence
    of their first character replaced by the letter given after ``->``.
    """
    def _collect_modifiers(toks):
        if len(toks) <= 1:
            modifiers.append(None)
            return toks

        collected = []
        modifiers.append(collected)

        out = []
        for start in range(0, len(toks), 3):
            sub = toks[start + 1]
            out.append(sub.replace(sub[0], toks[2]))
            collected.append(list(toks))

        return out

    return _collect_modifiers


def parse_term_expression(texpr):
    """
    Parse a comma-separated term expression such as ``'00,i.j,s(i:j)->I'``.

    Each argument is a sequence of plain subscripts (letters, digits, ``.``,
    ``:``) or of modified subscripts ``mod(subscripts)->letter``, where the
    only supported modifier is ``s``.

    Returns
    -------
    eins : pyparsing.ParseResults
        The per-argument subscript strings.
    modifiers : list
        Entries recorded by the :func:`collect_modifiers` parse action, one
        per parsed argument.
    """
    lparen, rparen = map(Suppress, '()')
    arrow = Literal('->').suppress()

    simple_arg = Word(alphanums + '.:')
    letter = Word(alphas, exact=1)

    modifier_names = 's'
    mod_arg = (oneOf(modifier_names) + lparen + simple_arg + rparen
               + arrow + letter)

    arg = OneOrMore(simple_arg ^ mod_arg)

    modifiers = []
    arg.setParseAction(collect_modifiers(modifiers))

    parser = delimitedList(Combine(arg))
    eins = parser.parseString(texpr, parseAll=True)

    return eins, modifiers


def append_all(seqs, item, ii=None):
    """
    Append `item` to every sequence in `seqs`, or only to ``seqs[ii]``.

    Note that the very same `item` object is appended to each sequence.
    """
    targets = seqs if ii is None else [seqs[ii]]
    for seq in targets:
        seq.append(item)


def get_sizes(indices, operands):
    """
    Return a dict mapping each subscript letter to the length of its axis.

    `indices` holds one subscript string per operand. When a letter occurs
    in several operands, the size from the last one wins.
    """
    return {letter : size
            for subs, operand in zip(indices, operands)
            for letter, size in zip(subs, operand.shape)}


def get_output_shape(out_subscripts, subscripts, operands):
    """
    Return the shape of the einsum result with output `out_subscripts`.

    Parameters
    ----------
    out_subscripts : str
        The output subscripts.
    subscripts : sequence of str
        The subscripts of the operands.
    operands : sequence of arrays
        The operands.

    Returns
    -------
    shape : tuple of int
        The output shape.
    """
    # The axis sizes do not depend on the output letter -> compute them once
    # (previously recomputed for every output subscript).
    sizes = get_sizes(subscripts, operands)
    return tuple(sizes[ii] for ii in out_subscripts)


def find_free_indices(indices):
    """
    Return the subscript letters that occur exactly once in `indices`.

    The letters are returned in the order of their first occurrence. The
    previous implementation iterated over a ``set``, whose order of
    single-character strings depends on the per-process hash seed. As the
    result is appended to output subscripts, the output axis order could
    then differ between runs.

    Parameters
    ----------
    indices : sequence of str
        The subscripts of all operands.

    Returns
    -------
    ifree : list of str
        The free letters.
    """
    ii = ''.join(indices)
    # dict.fromkeys() de-duplicates while keeping the first-occurrence order.
    ifree = [c for c in dict.fromkeys(ii) if ii.count(c) == 1]
    return ifree


def get_loop_indices(subs, loop_index):
    """
    Return, for each subscript string in `subs`, the position of
    `loop_index` in it, or None when the string does not contain it.
    """
    positions = []
    for indices in subs:
        if loop_index in indices:
            positions.append(indices.index(loop_index))

        else:
            positions.append(None)

    return positions


def get_einsum_ops(eargs, ebuilder, expr_cache):
    """
    Collect the einsum operands of all expressions built by `ebuilder`.

    Operand names have the form ``'<argument name>.<value name>'``:

    - ``dofs`` values come from the argument's ``get_dofs()``, using a cache
      stored in ``arg.arg.evaluate_cache['dofs'][0]``;
    - ``I`` and ``Psg`` come from ``ebuilder.make_eye()`` and
      ``ebuilder.make_psg()``;
    - any other value comes from the argument's ``get()``.

    If ``ebuilder.components`` lists fixed components for an operand, the
    operand's axes after the first two are indexed by them (None keeps the
    whole axis).

    Returns
    -------
    operands : list of lists
        The operands of each of the ``ebuilder.n_add`` expressions.
    """
    args_by_name = {earg.name : earg for earg in eargs}

    def _fetch(oname):
        arg_name, val_name = oname.split('.')
        arg = args_by_name[arg_name]

        if val_name == 'dofs':
            step_cache = arg.arg.evaluate_cache.setdefault('dofs', {})
            cache = step_cache.setdefault(0, {})
            return arg.get_dofs(cache, expr_cache, oname)

        if val_name == 'I':
            return ebuilder.make_eye(arg.n_components)

        if val_name == 'Psg':
            return ebuilder.make_psg(arg.dim)

        return arg.get(val_name,
                       msg_if_none='{} has no attribute {}!'
                       .format(arg_name, val_name))

    operands = []
    for ia in range(ebuilder.n_add):
        expr_ops = []
        for io, oname in enumerate(ebuilder.operand_names[ia]):
            op = _fetch(oname)

            ics = ebuilder.components[ia][io]
            if len(ics) > 0:
                index = ([slice(None)] * 2
                         + [slice(None) if ic is None else ic for ic in ics])
                op = op[tuple(index)]

            expr_ops.append(op)

        operands.append(expr_ops)

    return operands


def get_slice_ops(subs, ops, loop_index):
    """
    Return a function ``slice_ops(ic)`` giving `ops` sliced at position `ic`
    of the `loop_index` axis.

    Operands whose subscripts do not contain `loop_index` are passed through
    unchanged. Sliced operands lose the loop axis.
    """
    loop_axes = get_loop_indices(subs, loop_index)

    def slice_ops(ic):
        sliced = []
        for iop, axis in enumerate(loop_axes):
            op = ops[iop]
            if axis is None:
                sliced.append(op)
                continue

            index = tuple(ic if iax == axis else slice(None, None)
                          for iax in range(op.ndim))
            sliced.append(op[index])

        return sliced

    return slice_ops


class ExpressionArg(Struct):
    """
    Term argument wrapper holding the data used to build einsum operands.

    Instances are created by :meth:`from_term_arg`. The ``kind`` attribute is
    one of ``'virtual'``, ``'state'`` or ``'ndarray'``.
    """

    @staticmethod
    def from_term_arg(arg, term, cache):
        """
        Create an ExpressionArg from a term argument.

        Parameters
        ----------
        arg : ExpressionArg, FieldVariable, numpy.ndarray or tuple
            The term argument. An ExpressionArg is returned unchanged. A tuple
            is expected to be ``(array, name)``.
        term : Term
            The term the argument belongs to. It provides the mappings, the
            DOF connectivity and the argument names.
        cache : dict
            Cache of the extracted base function arrays, keyed by
            ``'bf<id of the mapping bf array>'``.

        Returns
        -------
        obj : ExpressionArg
            The wrapped argument.

        Raises
        ------
        ValueError
            If the argument type is not supported (including a FieldVariable
            that is neither virtual nor state/parameter).
        """
        from sfepy.discrete import FieldVariable

        if isinstance(arg, ExpressionArg):
            return arg

        if isinstance(arg, FieldVariable):
            ag, _ = term.get_mapping(arg)
            bf = ag.bf
            key = 'bf{}'.format(id(bf))
            _bf = cache.get(key)
            if bf.shape[0] > 1: # cell-depending basis.
                if _bf is None:
                    # Keep the leading (cell) axis.
                    _bf = bf[:, :, 0]
                    cache[key] = _bf

            else:
                if _bf is None:
                    # Basis shared by all cells -> drop the leading axis.
                    _bf = bf[0, :, 0]
                    cache[key] = _bf

        if isinstance(arg, FieldVariable) and arg.is_virtual():
            ag, _ = term.get_mapping(arg)
            obj = ExpressionArg(name=arg.name, arg=arg, bf=_bf, bfg=ag.bfg,
                                det=ag.det[..., 0, 0],
                                n_components=arg.n_components,
                                dim=arg.dim,
                                kind='virtual')

        elif isinstance(arg, FieldVariable) and arg.is_state_or_parameter():
            ag, _ = term.get_mapping(arg)
            conn = arg.field.get_econn(term.get_dof_conn_type(),
                                       term.region)
            # (cells, components, nodes)
            shape = (ag.n_el, arg.n_components, ag.bf.shape[-1])
            obj = ExpressionArg(name=arg.name, arg=arg, bf=_bf, bfg=ag.bfg,
                                det=ag.det[..., 0, 0],
                                region_name=term.region.name,
                                conn=conn, shape=shape,
                                n_components=arg.n_components,
                                dim=arg.dim,
                                kind='state')

        elif isinstance(arg, nm.ndarray):
            aux = term.get_args()
            # Find arg in term arguments using a loop (numpy arrays cannot be
            # compared) to get its name. Raises IndexError if not found.
            ii = [ii for ii in range(len(term.args)) if aux[ii] is arg][0]
            obj = ExpressionArg(name='_'.join(term.arg_names[ii]), arg=arg,
                                kind='ndarray')

        elif isinstance(arg, tuple) and isinstance(arg[0], nm.ndarray):
            # (array, name) pair.
            obj = ExpressionArg(name=arg[1], arg=arg[0], kind='ndarray')

        else:
            raise ValueError('unknown argument type! ({})'.format(type(arg)))

        return obj

    def get_dofs(self, cache, expr_cache, oname):
        """
        Return the DOF values of a state argument gathered per cell.

        The result has the axes (cell, component, node), or (cell, node) when
        the variable has a single component. It is cached in `cache` under
        ``(name, region_name)``. On a cache miss, the `expr_cache` entries
        whose key is a tuple starting with `oname` are removed. These are the
        layout-transposed DOF copies stored by
        ``ExpressionBuilder.apply_layout()``, which would otherwise be stale.

        Returns None for arguments of other kinds than ``'state'``.
        """
        if self.kind != 'state': return

        key = (self.name, self.region_name)
        dofs = cache.get(key)
        if dofs is None:
            arg = self.arg
            # NOTE(review): calling the variable presumably returns its DOF
            # vector; the reshape gives one row per node.
            dofs_vec = self.arg().reshape((-1, arg.n_components))
            # # axis 0: cells, axis 1: node, axis 2: component
            # dofs = dofs_vec[conn]
            # axis 0: cells, axis 1: component, axis 2: node
            dofs = dofs_vec[self.conn].transpose((0, 2, 1))
            if arg.n_components == 1:
                dofs.shape = (dofs.shape[0], -1)
            cache[key] = dofs

            # New dofs -> clear dofs from expression cache.
            for key in list(expr_cache.keys()):
                if isinstance(key, tuple) and (key[0] == oname):
                    expr_cache.pop(key)

        return dofs


class ExpressionBuilder(Struct):
    """
    Builder of einsum expressions (subscripts and operand names) from a term
    expression string.

    ``n_add`` expressions are built in parallel; the evaluation functions sum
    their results. With a differentiation variable, the ``ia``-th expression
    differentiates the ``ia``-th state argument that is that variable (see
    :meth:`add_state_arg`).

    For each expression ``ia``, the lists ``subscripts[ia]``,
    ``operand_names[ia]`` and ``components[ia]`` are kept aligned: item ``io``
    of each describes the same operand.

    Letters: ``c`` cells, ``q`` quadrature points, ``d``-``h`` basis (node)
    axes of the arguments, ``r``-``z`` auxiliary (component) axes.
    """
    letters = 'defgh'
    _aux_letters = 'rstuvwxyz'

    def __init__(self, n_add, cache):
        self.n_add = n_add
        self.subscripts = [[] for ia in range(n_add)]
        self.operand_names = [[] for ia in range(n_add)]
        self.components = [[] for ia in range(n_add)]
        self.out_subscripts = ['c' for ia in range(n_add)]
        self.ia = 0
        self.cache = cache
        self.aux_letters = iter(self._aux_letters)

    def make_eye(self, size):
        """Return the (cached) identity matrix of the given size."""
        key = 'I{}'.format(size)
        ee = self.cache.get(key)
        if ee is None:
            ee = nm.eye(size)
            self.cache[key] = ee

        return ee

    def make_psg(self, dim):
        """
        Return the (cached) ``(dim, dim, sym)`` 0/1 array mapping a pair of
        component indices to the symmetric (vector) storage index.

        The storage order is 11, 22, 33, 12, 13, 23 in 3D, 11, 22, 12 in 2D
        and 11 in 1D.
        """
        key = 'Psg{}'.format(dim)
        psg = self.cache.get(key)
        if psg is None:
            sym = dim2sym(dim)
            psg = nm.zeros((dim, dim, sym))
            if dim == 3:
                psg[0, [0,1,2], [0,3,4]] = 1
                psg[1, [0,1,2], [3,1,5]] = 1
                psg[2, [0,1,2], [4,5,2]] = 1

            elif dim == 2:
                psg[0, [0,1], [0,2]] = 1
                psg[1, [0,1], [2,1]] = 1

            elif dim == 1:
                # Previously left all zeros -> zero symmetric gradient in 1D.
                psg[0, 0, 0] = 1

            self.cache[key] = psg

        return psg

    def add_constant(self, name, cname):
        """Add the per-(cell, QP) operand ``<name>.<cname>`` to all expressions."""
        append_all(self.subscripts, 'cq')
        append_all(self.operand_names, '.'.join((name, cname)))
        append_all(self.components, [])

    def add_bfg(self, iin, ein, name):
        """Add the base function gradient operand of `name`."""
        append_all(self.subscripts, 'cq{}{}'.format(ein[2], iin))
        append_all(self.operand_names, name + '.bfg')
        append_all(self.components, [])

    def add_bf(self, iin, ein, name, cell_dependent=False):
        """
        Add the base function operand of `name`. A cell-dependent base has an
        extra leading cell axis.
        """
        if cell_dependent:
            append_all(self.subscripts, 'cq{}'.format(iin))

        else:
            append_all(self.subscripts, 'q{}'.format(iin))

        append_all(self.operand_names, name + '.bf')
        append_all(self.components, [])

    def add_eye(self, iic, ein, name, iia=None):
        """
        Add the identity operand of `name` to expression `iia` (all if None).
        """
        append_all(self.subscripts, '{}{}'.format(ein[0], iic), ii=iia)
        append_all(self.operand_names, '{}.I'.format(name), ii=iia)
        # Keep components aligned with operand_names of expression iia.
        append_all(self.components, [], ii=iia)

    def add_psg(self, iic, ein, name, iia=None):
        """
        Add the symmetric storage operand of `name` to expression `iia` (all
        if None).
        """
        append_all(self.subscripts, '{}{}{}'.format(iic, ein[2], ein[0]),
                   ii=iia)
        append_all(self.operand_names, name + '.Psg', ii=iia)
        # Keep components aligned with operand_names of expression iia.
        append_all(self.components, [], ii=iia)

    def add_arg_dofs(self, iin, ein, name, n_components, iia=None):
        """
        Add the DOF operand of `name` to expression `iia` (all if None).
        """
        if n_components > 1:
            term = 'c{}{}'.format(ein[0], iin)

        else:
            term = 'c{}'.format(iin)

        append_all(self.subscripts, term, ii=iia)
        append_all(self.operand_names, name + '.dofs', ii=iia)
        # Keep components aligned with operand_names of expression iia.
        append_all(self.components, [], ii=iia)

    @staticmethod
    def _is_cell_dependent(arg):
        # ExpressionArg stores a per-cell base as a 3D (c, q, d) array.
        return nm.ndim(getattr(arg, 'bf', None)) == 3

    def add_virtual_arg(self, arg, ii, ein, modifier):
        """Add the operands of the virtual argument `arg`."""
        iin = self.letters[ii] # node (qs basis index)
        if ('.' in ein) or (':' in ein): # derivative, symmetric gradient
            self.add_bfg(iin, ein, arg.name)

        else:
            self.add_bf(iin, ein, arg.name,
                        cell_dependent=self._is_cell_dependent(arg))

        out_letters = iin

        if arg.n_components > 1:
            iic = next(self.aux_letters) # component
            if ':' not in ein:
                self.add_eye(iic, ein, arg.name)

            else: # symmetric gradient
                if modifier[0][0] == 's': # vector storage
                    self.add_psg(iic, ein, arg.name)

                else:
                    raise ValueError('unknown argument modifier! ({})'
                                     .format(modifier))

            out_letters = iic + out_letters

        for iia in range(self.n_add):
            self.out_subscripts[iia] += out_letters

    def add_state_arg(self, arg, ii, ein, modifier, diff_var):
        """
        Add the operands of the state argument `arg`.

        If `arg` is the differentiation variable, expression ``self.ia``
        leaves its DOF axes free (output) while the other expressions use
        the DOF values.
        """
        iin = self.letters[ii] # node (qs basis index)
        if ('.' in ein) or (':' in ein): # derivative, symmetric gradient
            self.add_bfg(iin, ein, arg.name)

        else:
            self.add_bf(iin, ein, arg.name,
                        cell_dependent=self._is_cell_dependent(arg))

        out_letters = iin

        if (diff_var != arg.name):
            if ':' not in ein:
                self.add_arg_dofs(iin, ein, arg.name, arg.n_components)

            else: # symmetric gradient
                if modifier[0][0] == 's': # vector storage
                    iic = next(self.aux_letters) # component
                    self.add_psg(iic, ein, arg.name)
                    self.add_arg_dofs(iin, [iic], arg.name, arg.n_components)

                else:
                    raise ValueError('unknown argument modifier! ({})'
                                     .format(modifier))

        else:
            if arg.n_components > 1:
                iic = next(self.aux_letters) # component
                if ':' in ein: # symmetric gradient
                    if modifier[0][0] != 's': # vector storage
                        raise ValueError('unknown argument modifier! ({})'
                                         .format(modifier))

                out_letters = iic + out_letters

            for iia in range(self.n_add):
                if iia != self.ia:
                    self.add_arg_dofs(iin, ein, arg.name, arg.n_components,
                                      iia)

                elif arg.n_components > 1:
                    if ':' not in ein:
                        self.add_eye(iic, ein, arg.name, iia)

                    else:
                        self.add_psg(iic, ein, arg.name, iia)

            self.out_subscripts[self.ia] += out_letters
            self.ia += 1

    def add_material_arg(self, arg, ii, ein):
        """
        Add the material (array) argument `arg`.

        Digits in `ein` select fixed components of the corresponding material
        axes. Those axes are indexed out in get_einsum_ops() and get no
        subscript letter.
        """
        ics = []
        rein = []
        for ie in ein:
            if str.isnumeric(ie):
                ics.append(int(ie))

            else:
                ics.append(None)
                rein.append(ie)

        rein = ''.join(rein)

        # A separate list per expression: previously one shared list was
        # appended to all expressions and then extended once per expression,
        # duplicating the component indices when n_add > 1.
        for comp in self.components:
            comp.append(list(ics))

        append_all(self.subscripts, 'cq{}'.format(rein))
        append_all(self.operand_names, arg.name + '.arg')

    def build(self, texpr, *args, diff_var=None):
        """Build the expressions for the term expression `texpr`."""
        eins, modifiers = parse_term_expression(texpr)

        # Virtual variable must be the first variable.
        # Numpy arrays cannot be compared -> use a loop.
        for iv, arg in enumerate(args):
            if arg.kind == 'virtual':
                self.add_constant(arg.name, 'det')
                self.add_virtual_arg(arg, iv, eins[iv], modifiers[iv])
                break

        else:
            iv = -1
            for ip, arg in enumerate(args):
                if arg.kind == 'state':
                    self.add_constant(arg.name, 'det')
                    break

            else:
                raise ValueError('no FieldVariable in arguments!')

        for ii, ein in enumerate(eins):
            if ii == iv: continue
            arg = args[ii]

            if arg.kind == 'ndarray':
                self.add_material_arg(arg, ii, ein)

            elif arg.kind == 'state':
                self.add_state_arg(arg, ii, ein, modifiers[ii], diff_var)

            else:
                raise ValueError('unknown argument type! ({})'
                                 .format(type(arg)))

        for ia, subscripts in enumerate(self.subscripts):
            ifree = [ii for ii in find_free_indices(subscripts)
                     if ii not in self.out_subscripts[ia]]
            if ifree:
                self.out_subscripts[ia] += ''.join(ifree)

    @staticmethod
    def join_subscripts(subscripts, out_subscripts):
        """Return the einsum string ``'sub1,sub2,...->out'``."""
        return ','.join(subscripts) + '->' + out_subscripts

    def get_expressions(self, subscripts=None):
        """Return a tuple of the einsum strings of all expressions."""
        if subscripts is None:
            subscripts = self.subscripts

        expressions = [self.join_subscripts(subscripts[ia],
                                            self.out_subscripts[ia])
                       for ia in range(self.n_add)]
        return tuple(expressions)

    def print_shapes(self, subscripts, operands):
        """Output the axis sizes and operand shapes of all expressions."""
        if subscripts is None:
            subscripts = self.subscripts

        output('number of expressions:', self.n_add)
        for onames, outs, subs, ops in zip(
            self.operand_names, self.out_subscripts, subscripts, operands,
        ):
            sizes = get_sizes(subs, ops)
            output(sizes)
            out_shape = get_output_shape(outs, subs, ops)
            output(outs, out_shape, '=')

            for name, ii, op in zip(onames, subs, ops):
                output('  {:10} {:8} {}'.format(name, ii, op.shape))

    def apply_layout(self, layout, operands, defaults=None, verbosity=0):
        """
        Transpose the operands (and permute their subscripts) to follow
        `layout`.

        Transposed copies are cached in ``self.cache``: DOF copies are keyed
        by the operand name, other copies by ``id()`` of the source array.
        """
        if layout == 'cqgvd0':
            return self.subscripts, operands

        if defaults is None:
            defaults = {
                'det' : 'cq',
                'bf' : ('qd', 'cqd'),
                'bfg' : 'cqgd',
                'dofs' : ('cd', 'cvd'),
                'mat' : 'cq',
            }

        mat_range = ''.join([str(ii) for ii in range(10)])
        new_subscripts = [subs.copy() for subs in self.subscripts]
        new_operands = [ops.copy() for ops in operands]
        for ia in range(self.n_add):
            for io, (oname, subs, op) in enumerate(zip(self.operand_names[ia],
                                                       self.subscripts[ia],
                                                       operands[ia])):
                arg_name, val_name = oname.split('.')

                if val_name in ('det', 'bfg'):
                    default = defaults[val_name]

                elif val_name in ('bf', 'dofs'):
                    default = defaults[val_name][op.ndim - 2]

                elif val_name in ('I', 'Psg'):
                    default = layout.replace('0', '') # -> Do nothing.

                else:
                    default = defaults['mat'] + mat_range[:(len(subs) - 2)]

                if '0' in default: # Material
                    inew = nm.array([default.find(il)
                                     for il in layout.replace('0', default[2:])
                                     if il in default])

                else:
                    inew = nm.array([default.find(il)
                                     for il in layout if il in default])

                new = ''.join([default[ii] for ii in inew])
                if verbosity > 2:
                    output(arg_name, val_name, subs, default, op.shape, layout)
                    output(inew, new)

                if new == default:
                    new_subscripts[ia][io] = subs
                    new_operands[ia][io] = op

                else:
                    new_subs = ''.join([subs[ii] for ii in inew])
                    if val_name == 'dofs':
                        key = (oname,) + tuple(inew)

                    else:
                        # id is unique only during object lifetime!
                        key = (id(op),) + tuple(inew)

                    new_op = self.cache.get(key)
                    if new_op is None:
                        new_op = op.transpose(inew).copy()
                        self.cache[key] = new_op

                    new_subscripts[ia][io] = new_subs
                    new_operands[ia][io] = new_op

                if verbosity > 2:
                    output('->', new_subscripts[ia][io])

        return new_subscripts, new_operands

    def transform(self, subscripts, operands, transformation='loop',
                  **kwargs):
        """
        Prepare the expressions for a looped or dask evaluation.

        ``'loop'`` returns the expressions without the loop index, the
        operands sliced at 0, the slicing functions and the loop sizes.
        ``'dask'`` returns the operands wrapped as dask arrays where they
        contain the loop index.
        """
        if transformation == 'loop':
            expressions, poperands, all_slice_ops, loop_sizes = [], [], [], []
            loop_index = kwargs.get('loop_index', 'c')
            for ia, (subs, out_subscripts, ops) in enumerate(zip(
                    subscripts, self.out_subscripts, operands
            )):
                slice_ops = get_slice_ops(subs, ops, loop_index)
                tsubs = [ii.replace(loop_index, '') for ii in subs]
                tout_subs = out_subscripts.replace(loop_index, '')
                expr = self.join_subscripts(tsubs, tout_subs)
                pops = slice_ops(0)
                expressions.append(expr)
                poperands.append(pops)
                all_slice_ops.append(slice_ops)
                loop_sizes.append(get_sizes(subs, ops)[loop_index])

            return tuple(expressions), poperands, all_slice_ops, loop_sizes

        elif transformation == 'dask':
            da_operands = []
            c_chunk_size = kwargs.get('c_chunk_size')
            loop_index = kwargs.get('loop_index', 'c')
            for ia in range(len(operands)):
                da_ops = []
                for name, ii, op in zip(self.operand_names[ia],
                                        subscripts[ia],
                                        operands[ia]):
                    if loop_index in ii:
                        if c_chunk_size is None:
                            chunks = 'auto'

                        else:
                            ic = ii.index(loop_index)
                            chunks = (op.shape[:ic]
                                      + (c_chunk_size,)
                                      + op.shape[ic + 1:])

                        da_op = da.from_array(op, chunks=chunks, name=name)

                    else:
                        da_op = op

                    da_ops.append(da_op)
                da_operands.append(da_ops)

            return da_operands

        else:
            raise ValueError('unknown transformation! ({})'
                             .format(transformation))


class ETermBase(Term):
    """
    Base class of terms evaluated by einsum-like contractions of a term
    expression.

    Reserved letters:

    c .. cells
    q .. quadrature points
    d-h .. DOFs axes
    r-z .. auxiliary axes

    Layout specification letters:

    c .. cells
    q .. quadrature points
    v .. variable component - matrix form (v, d) -> vector v*d
    g .. gradient component
    d .. local DOF (basis, node)
    0 .. all material axes
    """
    verbosity = 0

    # Backend name -> module it needs. A falsy value means that the optional
    # dependency failed to import and the backend is unavailable.
    can_backend = {
        'numpy' : nm,
        'numpy_loop' : nm,
        'numpy_qloop' : nm,
        'opt_einsum' : oe,
        'opt_einsum_loop' : oe,
        'opt_einsum_qloop' : oe,
        'jax' : jnp,
        'jax_vmap' : jnp,
        'dask_single' : da,
        'dask_threads' : da,
        'opt_einsum_dask_single' : oe and da,
        'opt_einsum_dask_threads' : oe and da,
    }

    layout_letters = 'cqgvd0'

    def __init__(self, *args, **kwargs):
        Term.__init__(self, *args, **kwargs)

        self.set_verbosity(kwargs.get('verbosity', 0))
        self.set_backend(**kwargs)

    @staticmethod
    def function_timer(out, eval_einsum, *args):
        """Call `eval_einsum` and output its wall-clock time."""
        tt = Timer('')
        tt.start()
        eval_einsum(out, *args)
        output('eval_einsum: {} s'.format(tt.stop()))
        return 0

    @staticmethod
    def function_silent(out, eval_einsum, *args):
        """Call `eval_einsum` without timing."""
        eval_einsum(out, *args)
        return 0

    def set_verbosity(self, verbosity=None):
        """Set the verbosity and choose the timed or silent term function."""
        if verbosity is not None:
            self.verbosity = verbosity

        if self.verbosity > 0:
            self.function = self.function_timer

        else:
            self.function = self.function_silent

    def set_backend(self, backend='numpy', optimize=True, layout=None,
                    **kwargs):
        """
        Set the evaluation backend, the contraction path optimization and the
        operand layout.

        Changing any of them clears the cached expression data.

        Raises
        ------
        ValueError
            If the backend is unknown or unavailable, or the layout is
            invalid.
        """
        if backend not in self.can_backend.keys():
            # Report the requested backend - self.backend need not exist yet.
            raise ValueError('backend {} not in {}!'
                             .format(backend, self.can_backend.keys()))

        if not self.can_backend[backend]:
            raise ValueError('backend {} is not available!'.format(backend))

        if (hasattr(self, 'backend')
            and (backend == self.backend) and (optimize == self.optimize)
            and (layout == self.layout) and (kwargs == self.backend_kwargs)):
            return

        if layout is not None:
            if set(layout) != set(self.layout_letters):
                raise ValueError('layout can contain only "{}" letters! ({})'
                                 .format(self.layout_letters, layout))
            self.layout = layout

        else:
            self.layout = self.layout_letters

        self.backend = backend
        self.optimize = optimize
        self.backend_kwargs = kwargs
        self.einfos = {}
        self.clear_cache()

    def clear_cache(self):
        """Clear the expression cache."""
        self.expr_cache = {}

    def build_expression(self, texpr, *eargs, diff_var=None):
        """
        Return an ExpressionBuilder for `texpr`. The builder holds one
        expression per state argument that is `diff_var`, or a single one
        when `diff_var` is None.
        """
        timer = Timer('')
        timer.start()

        if diff_var is not None:
            n_add = len([arg.name for arg in eargs
                         if (arg.kind == 'state')
                         and (arg.name == diff_var)])

        else:
            n_add = 1

        ebuilder = ExpressionBuilder(n_add, self.expr_cache)
        ebuilder.build(texpr, *eargs, diff_var=diff_var)
        if self.verbosity:
            output('build expression: {} s'.format(timer.stop()))

        return ebuilder

    def make_function(self, texpr, *args, diff_var=None):
        """
        Return the evaluation function of the term expression `texpr` for the
        current backend.

        The function, the wrapped arguments and the expression builder are
        cached per `diff_var` in ``self.einfos``.
        """
        timer = Timer('')
        timer.start()

        einfo = self.einfos.setdefault(diff_var, Struct(
            eargs=None,
            ebuilder=None,
            paths=None,
            path_infos=None,
            eval_einsum=None,
        ))

        if einfo.eval_einsum is not None:
            if self.verbosity:
                output('einsum setup: {} s'.format(timer.stop()))
            return einfo.eval_einsum

        if einfo.eargs is None:
            einfo.eargs = [
                ExpressionArg.from_term_arg(arg, self, self.expr_cache)
                for arg in args
            ]

        if einfo.ebuilder is None:
            einfo.ebuilder = self.build_expression(texpr, *einfo.eargs,
                                                   diff_var=diff_var)

        n_add = einfo.ebuilder.n_add

        if self.backend in ('numpy', 'opt_einsum'):
            contract = nm.einsum if self.backend == 'numpy' else oe.contract

            def eval_einsum_orig(out, eshape, expressions, operands, paths):
                if operands[0][0].flags.c_contiguous:
                    # This is very slow if vout layout differs from operands
                    # layout.
                    vout = out.reshape(eshape)
                    contract(expressions[0], *operands[0], out=vout,
                             optimize=paths[0])

                else:
                    aux = contract(expressions[0], *operands[0],
                                   optimize=paths[0])
                    # Assign the first contribution, as the out= branch above
                    # does (previously it was added to the unzeroed out).
                    out[:] = aux.reshape(out.shape)

                for ia in range(1, n_add):
                    aux = contract(expressions[ia], *operands[ia],
                                   optimize=paths[ia])
                    out[:] += aux.reshape(out.shape)

            def eval_einsum0(out, eshape, expressions, operands, paths):
                aux = contract(expressions[0], *operands[0],
                               optimize=paths[0])
                out[:] = aux.reshape(out.shape)
                for ia in range(1, n_add):
                    aux = contract(expressions[ia], *operands[ia],
                                   optimize=paths[ia])
                    out[:] += aux.reshape(out.shape)

            def eval_einsum1(out, eshape, expressions, operands, paths):
                out.reshape(-1)[:] = contract(
                    expressions[0], *operands[0], optimize=paths[0],
                ).reshape(-1)
                for ia in range(1, n_add):
                    out.reshape(-1)[:] += contract(
                        expressions[ia], *operands[ia], optimize=paths[ia],
                    ).reshape(-1)

            def eval_einsum2(out, eshape, expressions, operands, paths):
                out.flat = contract(
                    expressions[0], *operands[0], optimize=paths[0],
                )
                for ia in range(1, n_add):
                    out.ravel()[...] += contract(
                        expressions[ia], *operands[ia], optimize=paths[ia],
                    ).ravel()

            def eval_einsum3(out, eshape, expressions, operands, paths):
                out.ravel()[:] = contract(
                    expressions[0], *operands[0], optimize=paths[0],
                ).ravel()
                for ia in range(1, n_add):
                    out.ravel()[:] += contract(
                        expressions[ia], *operands[ia], optimize=paths[ia],
                    ).ravel()

            def eval_einsum4(out, eshape, expressions, operands, paths):
                vout = out.reshape(eshape)
                contract(expressions[0], *operands[0], out=vout,
                         optimize=paths[0])
                for ia in range(1, n_add):
                    aux = contract(expressions[ia], *operands[ia],
                                   optimize=paths[ia])
                    out[:] += aux.reshape(out.shape)

            # Explicit dispatch table instead of a locals() lookup.
            eval_funs = {
                'eval_einsum_orig' : eval_einsum_orig,
                'eval_einsum0' : eval_einsum0,
                'eval_einsum1' : eval_einsum1,
                'eval_einsum2' : eval_einsum2,
                'eval_einsum3' : eval_einsum3,
                'eval_einsum4' : eval_einsum4,
            }
            eval_fun = self.backend_kwargs.get('eval_fun', 'eval_einsum0')
            if eval_fun not in eval_funs:
                raise ValueError('unknown eval_fun! ({}, not in {})'
                                 .format(eval_fun, list(eval_funs.keys())))
            eval_einsum = eval_funs[eval_fun]

        elif self.backend in ('numpy_loop', 'opt_einsum_loop'):
            contract = nm.einsum if self.backend == 'numpy_loop' else oe.contract

            def eval_einsum(out, eshape, expressions, all_slice_ops, paths):
                n_cell = out.shape[0]
                vout = out.reshape(eshape)
                slice_ops = all_slice_ops[0]
                if vout.ndim > 1:
                    for ic in range(n_cell):
                        ops = slice_ops(ic)
                        contract(expressions[0], *ops, out=vout[ic],
                                 optimize=paths[0])

                else: # vout[ic] can be scalar in eval mode.
                    for ic in range(n_cell):
                        ops = slice_ops(ic)
                        vout[ic] = contract(expressions[0], *ops,
                                            optimize=paths[0])

                for ia in range(1, n_add):
                    slice_ops = all_slice_ops[ia]
                    for ic in range(n_cell):
                        ops = slice_ops(ic)
                        vout[ic] += contract(expressions[ia], *ops,
                                             optimize=paths[ia])

        elif self.backend in ('numpy_qloop', 'opt_einsum_qloop'):
            contract = (nm.einsum if self.backend == 'numpy_qloop'
                        else oe.contract)

            def eval_einsum(out, eshape, expressions, all_slice_ops,
                            loop_sizes, paths):
                n_qp = loop_sizes[0]
                vout = out.reshape(eshape)
                slice_ops = all_slice_ops[0]

                ops = slice_ops(0)
                vout[:] = contract(expressions[0], *ops, optimize=paths[0])
                for iq in range(1, n_qp):
                    ops = slice_ops(iq)
                    vout[:] += contract(expressions[0], *ops,
                                        optimize=paths[0])

                for ia in range(1, n_add):
                    n_qp = loop_sizes[ia]
                    slice_ops = all_slice_ops[ia]
                    for iq in range(n_qp):
                        ops = slice_ops(iq)
                        vout[:] += contract(expressions[ia], *ops,
                                            optimize=paths[ia])

        elif self.backend == 'jax':
            # jax.partial was removed from JAX; it was functools.partial.
            from functools import partial

            @partial(jax.jit, static_argnums=(0, 1, 2))
            def _eval_einsum(expressions, paths, n_add, operands):
                val = jnp.einsum(expressions[0], *operands[0],
                                 optimize=paths[0])
                for ia in range(1, n_add):
                    val += jnp.einsum(expressions[ia], *operands[ia],
                                      optimize=paths[ia])
                return val

            def eval_einsum(out, eshape, expressions, operands, paths):
                aux = _eval_einsum(expressions, paths, n_add, operands)
                out[:] = nm.asarray(aux.reshape(out.shape))

        elif self.backend == 'jax_vmap':
            def _eval_einsum_cell(expressions, paths, n_add, operands):
                val = jnp.einsum(expressions[0], *operands[0],
                                 optimize=paths[0])
                for ia in range(1, n_add):
                    val += jnp.einsum(expressions[ia], *operands[ia],
                                      optimize=paths[ia])
                return val

            def eval_einsum(out, vmap_eval_cell, eshape, expressions,
                            operands, paths):
                aux = vmap_eval_cell(expressions, paths, n_add,
                                     operands)
                out[:] = nm.asarray(aux.reshape(out.shape))

            # get_fargs() vmaps the per-cell function.
            eval_einsum = (eval_einsum, _eval_einsum_cell)

        elif self.backend.startswith('dask'):
            scheduler = {'dask_single' : 'single-threaded',
                         'dask_threads' : 'threads'}[self.backend]
            def eval_einsum(out, eshape, expressions, operands, paths):
                _out = da.einsum(expressions[0], *operands[0],
                                 optimize=paths[0])
                for ia in range(1, n_add):
                    aux = da.einsum(expressions[ia],
                                    *operands[ia],
                                    optimize=paths[ia])
                    _out += aux

                out[:] = _out.compute(scheduler=scheduler).reshape(out.shape)

        elif self.backend.startswith('opt_einsum_dask'):
            scheduler = {'opt_einsum_dask_single' : 'single-threaded',
                         'opt_einsum_dask_threads' : 'threads'}[self.backend]

            def eval_einsum(out, eshape, expressions, operands, paths):
                _out = oe.contract(expressions[0], *operands[0],
                                   optimize=paths[0],
                                   backend='dask')
                for ia in range(1, n_add):
                    aux = oe.contract(expressions[ia],
                                      *operands[ia],
                                      optimize=paths[ia],
                                      backend='dask')
                    _out += aux

                out[:] = _out.compute(scheduler=scheduler).reshape(out.shape)

        else:
            raise ValueError('unsupported backend! ({})'.format(self.backend))

        einfo.eval_einsum = eval_einsum

        if self.verbosity:
            output('einsum setup: {} s'.format(timer.stop()))

        return eval_einsum

    def get_operands(self, diff_var):
        """Return the einsum operands of the expressions for `diff_var`."""
        einfo = self.einfos[diff_var]
        return get_einsum_ops(einfo.eargs, einfo.ebuilder, self.expr_cache)

    def get_paths(self, expressions, operands):
        """
        Return the contraction paths and path infos of `expressions`,
        computed by the path optimizer of the current backend.
        """
        memory_limit = self.backend_kwargs.get('memory_limit')

        if ('numpy' in self.backend) or self.backend.startswith('dask'):
            optimize = (self.optimize if memory_limit is None
                        else (self.optimize, memory_limit))
            paths, path_infos = zip(*[nm.einsum_path(
                expressions[ia], *operands[ia],
                optimize=optimize,
            ) for ia in range(len(operands))])

        elif 'opt_einsum' in self.backend:
            paths, path_infos = zip(*[oe.contract_path(
                expressions[ia], *operands[ia],
                optimize=self.optimize,
                memory_limit=memory_limit,
            ) for ia in range(len(operands))])

        elif 'jax' in self.backend:
            paths, path_infos = [], []
            for ia in range(len(operands)):
                path, info = jnp.einsum_path(
                    expressions[ia], *operands[ia],
                    optimize=self.optimize,
                )
                paths.append(tuple(path))
                path_infos.append(info)
            paths = tuple(paths)
            path_infos = tuple(path_infos)

        else:
            raise ValueError('unsupported backend! ({})'.format(self.backend))

        return paths, path_infos

    def get_fargs(self, *args, **kwargs):
        """
        Return the arguments of the term function: the evaluation function,
        the output shape, the expressions, the (transformed) operands and the
        contraction paths.
        """
        mode, term_mode, diff_var = args[-3:]

        eval_einsum = self.get_function(*args, **kwargs)
        operands = self.get_operands(diff_var)

        einfo = self.einfos[diff_var]
        ebuilder = einfo.ebuilder
        eshape = get_output_shape(ebuilder.out_subscripts[0],
                                  ebuilder.subscripts[0], operands[0])

        out = [eval_einsum, eshape]

        subscripts, operands = ebuilder.apply_layout(
            self.layout, operands, verbosity=self.verbosity,
        )

        self.parsed_expressions = ebuilder.get_expressions(subscripts)
        if self.verbosity:
            output('parsed expressions:', self.parsed_expressions)

        cloop = self.backend in ('numpy_loop', 'opt_einsum_loop', 'jax_vmap')
        qloop = self.backend in ('numpy_qloop', 'opt_einsum_qloop')
        if cloop or qloop:
            loop_index = 'c' if cloop else 'q'
            transform = ebuilder.transform(subscripts, operands,
                                           transformation='loop',
                                           loop_index=loop_index)
            expressions, poperands, all_slice_ops, loop_sizes = transform

            if self.backend == 'jax_vmap':
                all_ics = [get_loop_indices(subs, loop_index)
                           for subs in subscripts]
                vms = (None, None, None, all_ics)
                vmap_eval_cell = jax.jit(jax.vmap(eval_einsum[1], vms, 0),
                                         static_argnums=(0, 1, 2))
                out += [expressions, operands]
                out[:1] = [eval_einsum[0], vmap_eval_cell]

            else:
                out += [expressions, all_slice_ops]
                if qloop:
                    out.append(loop_sizes)

        elif (self.backend.startswith('dask')
              or self.backend.startswith('opt_einsum_dask')):
            c_chunk_size = self.backend_kwargs.get('c_chunk_size')
            da_operands = ebuilder.transform(subscripts, operands,
                                             transformation='dask',
                                             c_chunk_size=c_chunk_size)
            poperands = operands

            expressions = self.parsed_expressions
            out += [expressions, da_operands]

        else:
            poperands = operands

            expressions = self.parsed_expressions
            out += [expressions, operands]

        if einfo.paths is None:
            if self.verbosity > 1:
                ebuilder.print_shapes(subscripts, operands)

            einfo.paths, einfo.path_infos = self.get_paths(
                expressions,
                poperands,
            )
            if self.verbosity > 2:
                for path, path_info in zip(einfo.paths, einfo.path_infos):
                    output('path:', path)
                    output(path_info)

        out += [einfo.paths]

        return out

    def get_eval_shape(self, *args, **kwargs):
        """
        Return the shape and dtype of the term value in evaluation modes.

        Raises
        ------
        ValueError
            If `diff_var` is given.
        """
        mode, term_mode, diff_var = args[-3:]
        if diff_var is not None:
            raise ValueError('cannot differentiate in {} mode!'
                             .format(mode))

        self.get_function(*args, **kwargs)

        operands = self.get_operands(diff_var)
        ebuilder = self.einfos[diff_var].ebuilder
        out_shape = get_output_shape(ebuilder.out_subscripts[0],
                                     ebuilder.subscripts[0], operands[0])

        # nm.find_common_type() was deprecated in NumPy 1.25 and removed in
        # 2.0; for array dtypes only, nm.result_type() is the replacement.
        dtype = nm.result_type(*[op.dtype for op in operands[0]])

        return out_shape, dtype

    def get_normals(self, arg):
        """
        Return the normals of the mapping of `arg` wrapped as an ndarray
        ExpressionArg named ``'n(<arg name>)'``, or None if the mapping has
        no normals.
        """
        normals = self.get_mapping(arg)[0].normal
        if normals is not None:
            normals = ExpressionArg(name='n({})'.format(arg.name),
                                    arg=normals[..., 0],
                                    kind='ndarray')
        return normals


class EIntegrateOperatorTerm(ETermBase):
    r"""
    Volume and surface integral of a test function weighted by a scalar
    function :math:`c`.

    :Definition:

    .. math::
        \int_\Omega q \mbox{ or } \int_\Omega c q

    :Arguments:
        - material : :math:`c` (optional)
        - virtual  : :math:`q`
    """
    name = 'de_integrate'
    arg_types = ('opt_material', 'virtual')
    arg_shapes = [{'opt_material' : '1, 1', 'virtual' : (1, None)},
                  {'opt_material' : None}]
    integration = 'by_region'

    def get_function(self, mat, virtual, mode=None, term_mode=None,
                     diff_var=None, **kwargs):
        """
        Return the evaluation function; the optional scalar material
        multiplies the virtual variable.
        """
        if mat is None:
            texpr, eargs = '0', (virtual,)

        else:
            texpr, eargs = '00,0', (mat, virtual)

        return self.make_function(texpr, *eargs, diff_var=diff_var)


class ELaplaceTerm(ETermBase):
    r"""
    Laplace term with :math:`c` coefficient. Can be
    evaluated. Can use derivatives.

    :Definition:

    .. math::
        \int_{\Omega} c \nabla q \cdot \nabla p \mbox{ , }
        \int_{\Omega} c \nabla \bar{p} \cdot \nabla r

    :Arguments 1:
        - material : :math:`c` (optional)
        - virtual  : :math:`q`
        - state    : :math:`p`

    :Arguments 2:
        - material    : :math:`c` (optional)
        - parameter_1 : :math:`\bar{p}`
        - parameter_2 : :math:`r`
    """
    name = 'de_laplace'
    arg_types = (('opt_material', 'virtual', 'state'),
                 ('opt_material', 'parameter_1', 'parameter_2'))
    arg_shapes = [{'opt_material' : '1, 1', 'virtual' : (1, 'state'),
                   'state' : 1, 'parameter_1' : 1, 'parameter_2' : 1},
                  {'opt_material' : None}]
    modes = ('weak', 'eval')

    def get_function(self, mat, virtual, state, mode=None, term_mode=None,
                     diff_var=None, **kwargs):
        """
        Return the evaluation function of the gradient dot product,
        optionally weighted by a scalar material.
        """
        texpr = '0.j,0.j'
        eargs = (virtual, state)
        if mat is not None:
            texpr = '00,' + texpr
            eargs = (mat,) + eargs

        return self.make_function(texpr, *eargs, diff_var=diff_var)


class EDotTerm(ETermBase):
    r"""
    Volume and surface :math:`L^2(\Omega)` weighted dot product for both scalar
    and vector fields. Can be evaluated. Can use derivatives.

    :Definition:

    .. math::
        \int_{\cal{D}} q p \mbox{ , } \int_{\cal{D}} \ul{v} \cdot \ul{u}
        \mbox{ , }
        \int_{\cal{D}} p r \mbox{ , } \int_{\cal{D}} \ul{u} \cdot \ul{w} \\
        \int_{\cal{D}} c q p \mbox{ , } \int_{\cal{D}} c \ul{v} \cdot \ul{u}
        \mbox{ , }
        \int_{\cal{D}} c p r \mbox{ , } \int_{\cal{D}} c \ul{u} \cdot \ul{w}
        \\
        \int_{\cal{D}} \ul{v} \cdot \ull{M} \cdot \ul{u}
        \mbox{ , }
        \int_{\cal{D}} \ul{u} \cdot \ull{M} \cdot \ul{w}

    :Arguments 1:
        - material : :math:`c` or :math:`\ull{M}` (optional)
        - virtual  : :math:`q` or :math:`\ul{v}`
        - state    : :math:`p` or :math:`\ul{u}`

    :Arguments 2:
        - material    : :math:`c` or :math:`\ull{M}` (optional)
        - parameter_1 : :math:`p` or :math:`\ul{u}`
        - parameter_2 : :math:`r` or :math:`\ul{w}`
    """
    name = 'de_dot'
    arg_types = (('opt_material', 'virtual', 'state'),
                 ('opt_material', 'parameter_1', 'parameter_2'))
    arg_shapes = [{'opt_material' : '1, 1', 'virtual' : (1, 'state'),
                   'state' : 1, 'parameter_1' : 1, 'parameter_2' : 1},
                  {'opt_material' : None},
                  {'opt_material' : '1, 1', 'virtual' : ('D', 'state'),
                   'state' : 'D', 'parameter_1' : 'D', 'parameter_2' : 'D'},
                  {'opt_material' : 'D, D'},
                  {'opt_material' : None}]
    modes = ('weak', 'eval')
    integration = 'by_region'

    def get_function(self, mat, virtual, state, mode=None, term_mode=None,
                     diff_var=None, **kwargs):
        """
        Return the evaluation function of the dot product. A material with
        a last axis longer than one is used as a matrix; otherwise it is a
        scalar weight.
        """
        if mat is None:
            texpr, eargs = 'i,i', (virtual, state)

        elif mat.shape[-1] > 1:
            texpr, eargs = 'ij,i,j', (mat, virtual, state)

        else:
            texpr, eargs = '00,i,i', (mat, virtual, state)

        return self.make_function(texpr, *eargs, diff_var=diff_var)


class EScalarDotMGradScalarTerm(ETermBase):
    r"""
    Volume dot product of a scalar gradient dotted with a material vector with
    a scalar.

    :Definition:

    .. math::
        \int_{\Omega} q \ul{y} \cdot \nabla p \mbox{ , }
        \int_{\Omega} p \ul{y} \cdot \nabla q

    :Arguments 1:
        - material : :math:`\ul{y}`
        - virtual  : :math:`q`
        - state    : :math:`p`

    :Arguments 2:
        - material : :math:`\ul{y}`
        - state    : :math:`p`
        - virtual  : :math:`q`
    """
    name = 'de_s_dot_mgrad_s'
    arg_types = (('material', 'virtual', 'state'),
                 ('material', 'state', 'virtual'))
    arg_shapes = [{'material' : 'D, 1',
                   'virtual/grad_state' : (1, None),
                   'state/grad_state' : 1,
                   'virtual/grad_virtual' : (1, None),
                   'state/grad_virtual' : 1}]
    modes = ('grad_state', 'grad_virtual')

    def get_function(self, mat, var1, var2, mode=None, term_mode=None,
                     diff_var=None, **kwargs):
        """
        Return the evaluation function: `var1` times the material vector
        dotted with the gradient of `var2`.
        """
        texpr = 'i0,0,0.i'
        eargs = (mat, var1, var2)
        return self.make_function(texpr, *eargs, diff_var=diff_var)


class ENonPenetrationPenaltyTerm(ETermBase):
    r"""
    Penalty formulation of the non-penetration condition, imposed in the
    weak sense.

    :Definition:

    .. math::
        \int_{\Gamma} c (\ul{n} \cdot \ul{v}) (\ul{n} \cdot \ul{u})

    :Arguments:
        - material : :math:`c`
        - virtual  : :math:`\ul{v}`
        - state    : :math:`\ul{u}`
    """
    name = 'de_non_penetration_p'
    arg_types = ('material', 'virtual', 'state')
    arg_shapes = {'material' : '1, 1',
                  'virtual' : ('D', 'state'), 'state' : 'D'}
    integration = 'surface'

    def get_function(self, mat, virtual, state, mode=None, term_mode=None,
                     diff_var=None, **kwargs):
        """
        Build the evaluation function of c (n . v) (n . u).
        """
        normals = self.get_normals(state)
        # The normals appear twice: once in each normal projection.
        operands = (mat, virtual, normals, state, normals)
        return self.make_function('00,i,i,j,j', *operands,
                                  diff_var=diff_var)
class EDivGradTerm(ETermBase):
    r"""
    Diffusion term for a vector field, optionally weighted by a viscosity.

    :Definition:

    .. math::
        \int_{\Omega} \nu\ \nabla \ul{v} : \nabla \ul{u} \mbox{ , }
        \int_{\Omega} \nu\ \nabla \ul{u} : \nabla \ul{w} \\
        \int_{\Omega} \nabla \ul{v} : \nabla \ul{u} \mbox{ , }
        \int_{\Omega} \nabla \ul{u} : \nabla \ul{w}

    :Arguments 1:
        - material : :math:`\nu` (viscosity, optional)
        - virtual  : :math:`\ul{v}`
        - state    : :math:`\ul{u}`

    :Arguments 2:
        - material    : :math:`\nu` (viscosity, optional)
        - parameter_1 : :math:`\ul{u}`
        - parameter_2 : :math:`\ul{w}`
    """
    name = 'de_div_grad'
    arg_types = (('opt_material', 'virtual', 'state'),
                 ('opt_material', 'parameter_1', 'parameter_2'))
    arg_shapes = [{'opt_material' : '1, 1', 'virtual' : ('D', 'state'),
                   'state' : 'D', 'parameter_1' : 'D', 'parameter_2' : 'D'},
                  {'opt_material' : None}]
    modes = ('weak', 'eval')

    def get_function(self, mat, virtual, state, mode=None, term_mode=None,
                     diff_var=None, **kwargs):
        """
        Build the evaluation function of the double contraction of the two
        gradients, scaled by the scalar material when one is given.
        """
        if mat is not None:
            return self.make_function('00,i.j,i.j', mat, virtual, state,
                                      diff_var=diff_var)

        return self.make_function('i.j,i.j', virtual, state,
                                  diff_var=diff_var)
class EConvectTerm(ETermBase):
    r"""
    Nonlinear convective term.

    :Definition:

    .. math::
        \int_{\Omega} ((\ul{u} \cdot \nabla) \ul{u}) \cdot \ul{v} \mbox{ , }
        \int_{\Omega} ((\ul{w} \cdot \nabla) \ul{w}) \cdot \bar{\ul{u}}

    :Arguments 1:
        - virtual : :math:`\ul{v}`
        - state   : :math:`\ul{u}`

    :Arguments 2:
        - parameter_1 : :math:`\bar{\ul{u}}`
        - parameter_2 : :math:`\ul{w}`
    """
    name = 'de_convect'
    arg_types = (('virtual', 'state'),
                 ('parameter_1', 'parameter_2'))
    arg_shapes = {'virtual' : ('D', 'state'), 'state' : 'D',
                  'parameter_1' : 'D', 'parameter_2' : 'D'}
    modes = ('weak', 'eval')

    def get_function(self, virtual, state, mode=None, term_mode=None,
                     diff_var=None, **kwargs):
        """
        Build the evaluation function of ((u . grad) u) . v.
        """
        # `state` is passed twice: once as the differentiated field and
        # once as the convecting velocity.
        operands = (virtual, state, state)
        return self.make_function('i,i.j,j', *operands, diff_var=diff_var)
class EDivTerm(ETermBase):
    r"""
    Divergence term, optionally weighted by a scalar material.

    :Definition:

    .. math::
        \int_{\Omega} \nabla \cdot \ul{v} \mbox{ , }
        \int_{\Omega} \nabla \cdot \ul{u} \\
        \int_{\Omega} c \nabla \cdot \ul{v} \mbox{ , }
        \int_{\Omega} c \nabla \cdot \ul{u}

    :Arguments 1:
        - material : :math:`c` (optional)
        - virtual  : :math:`\ul{v}`

    :Arguments 2:
        - material  : :math:`c` (optional)
        - parameter : :math:`\ul{u}`
    """
    name = 'de_div'
    arg_types = (('opt_material', 'virtual'),
                 ('opt_material', 'parameter'),)
    arg_shapes = [{'opt_material' : '1, 1', 'virtual' : ('D', None),
                   'parameter' : 'D'},
                  {'opt_material' : None}]
    modes = ('weak', 'eval')

    def get_function(self, mat, virtual, mode=None, term_mode=None,
                     diff_var=None, **kwargs):
        """
        Build the evaluation function of the divergence of `virtual`,
        scaled by `mat` when it is given.
        """
        texpr, operands = (('i.i', (virtual,)) if mat is None
                           else ('00,i.i', (mat, virtual)))
        return self.make_function(texpr, *operands, diff_var=diff_var)
class EStokesTerm(ETermBase):
    r"""
    Stokes problem coupling term, covering the weak forms of both the
    gradient and the divergence terms.

    :Definition:

    .. math::
        \int_{\Omega} p\ \nabla \cdot \ul{v} \mbox{ , }
        \int_{\Omega} q\ \nabla \cdot \ul{u}
        \mbox{ or }
        \int_{\Omega} c\ p\ \nabla \cdot \ul{v} \mbox{ , }
        \int_{\Omega} c\ q\ \nabla \cdot \ul{u} \\
        \int_{\Omega} r\ \nabla \cdot \ul{w} \mbox{ , }
        \int_{\Omega} c r\ \nabla \cdot \ul{w}

    :Arguments 1:
        - material : :math:`c` (optional)
        - virtual  : :math:`\ul{v}`
        - state    : :math:`p`

    :Arguments 2:
        - material : :math:`c` (optional)
        - state    : :math:`\ul{u}`
        - virtual  : :math:`q`

    :Arguments 3:
        - material    : :math:`c` (optional)
        - parameter_v : :math:`\ul{u}`
        - parameter_s : :math:`p`
    """
    name = 'de_stokes'
    arg_types = (('opt_material', 'virtual', 'state'),
                 ('opt_material', 'state', 'virtual'),
                 ('opt_material', 'parameter_v', 'parameter_s'))
    arg_shapes = [{'opt_material' : '1, 1',
                   'virtual/grad' : ('D', None), 'state/grad' : 1,
                   'virtual/div' : (1, None), 'state/div' : 'D',
                   'parameter_v' : 'D', 'parameter_s' : 1},
                  {'opt_material' : None}]
    modes = ('grad', 'div', 'eval')

    def get_function(self, coef, vvar, svar, mode=None, term_mode=None,
                     diff_var=None, **kwargs):
        """
        Build the evaluation function of the divergence of `vvar` times
        `svar`, scaled by `coef` when it is given.

        `arg_types` always lists the vector variable before the scalar one,
        so a single expression serves the 'grad', 'div' and 'eval' modes.
        """
        if coef is None:
            texpr, operands = 'i.i,0', (vvar, svar)
        else:
            texpr, operands = '00,i.i,0', (coef, vvar, svar)
        return self.make_function(texpr, *operands, diff_var=diff_var)
class ELinearElasticTerm(ETermBase):
    r"""
    General linear elasticity term, with :math:`D_{ijkl}` given in
    the usual matrix form exploiting symmetry: in 3D it is :math:`6\times6`
    with the indices ordered as :math:`[11, 22, 33, 12, 13, 23]`, in 2D it is
    :math:`3\times3` with the indices ordered as :math:`[11, 22, 12]`.

    :Definition:

    .. math::
        \int_{\Omega} D_{ijkl}\ e_{ij}(\ul{v}) e_{kl}(\ul{u}) \mbox{ , }
        \int_{\Omega} D_{ijkl}\ e_{ij}(\ul{w}) e_{kl}(\ul{u})

    :Arguments 1:
        - material : :math:`D_{ijkl}`
        - virtual  : :math:`\ul{v}`
        - state    : :math:`\ul{u}`

    :Arguments 2:
        - material    : :math:`D_{ijkl}`
        - parameter_1 : :math:`\ul{w}`
        - parameter_2 : :math:`\ul{u}`
    """
    name = 'de_lin_elastic'
    arg_types = (('material', 'virtual', 'state'),
                 ('material', 'parameter_1', 'parameter_2'))
    arg_shapes = {'material' : 'S, S', 'virtual' : ('D', 'state'),
                  'state' : 'D', 'parameter_1' : 'D', 'parameter_2' : 'D'}
    modes = ('weak', 'eval')

    def get_function(self, mat, virtual, state, mode=None, term_mode=None,
                     diff_var=None, **kwargs):
        """
        Build the evaluation function of D_ijkl e_ij(v) e_kl(u).
        """
        # Each 's(a:b)->X' modifier collapses an index pair into one index
        # that matches the 'S, S' material (presumably the symmetric strain
        # in the vector storage described above; verify against
        # ETermBase).
        texpr = 'IK,s(i:j)->I,s(k:l)->K'
        return self.make_function(texpr, mat, virtual, state,
                                  diff_var=diff_var)
class ECauchyStressTerm(ETermBase):
    r"""
    Evaluate the Cauchy stress tensor.

    The result uses the usual vector form exploiting symmetry: in 3D it has
    6 components with the indices ordered as
    :math:`[11, 22, 33, 12, 13, 23]`, in 2D it has 3 components with the
    indices ordered as :math:`[11, 22, 12]`.

    :Definition:

    .. math::
        \int_{\Omega} D_{ijkl} e_{kl}(\ul{w})

    :Arguments:
        - material  : :math:`D_{ijkl}`
        - parameter : :math:`\ul{w}`
    """
    name = 'de_cauchy_stress'
    arg_types = ('material', 'parameter')
    arg_shapes = {'material' : 'S, S', 'parameter' : 'D'}

    def get_function(self, mat, parameter, mode=None, term_mode=None,
                     diff_var=None, **kwargs):
        """
        Build the evaluation function of D_ijkl e_kl(w). The material row
        index 'I' is left uncontracted, so it remains in the result.
        """
        texpr = 'IK,s(k:l)->K'
        return self.make_function(texpr, mat, parameter, diff_var=diff_var)