diff --git a/src/pystencils_reco/_assignment_collection.py b/src/pystencils_reco/_assignment_collection.py index e236e032583451b2ef4866c85734117b533955e7..86c3a5cbc4618becf4bcee768e43ca1cab4c995c 100644 --- a/src/pystencils_reco/_assignment_collection.py +++ b/src/pystencils_reco/_assignment_collection.py @@ -105,7 +105,7 @@ class AssignmentCollection(pystencils.AssignmentCollection): self.args = [] self.kwargs = {} self._autodiff = None - self.kernel = None + self._kernel = None # @property # def reproducible_hash(self): @@ -118,6 +118,15 @@ class AssignmentCollection(pystencils.AssignmentCollection): # def __getstate__(self): # return self.reproducible_hash + @property + def kernel(self): + if not self._kernel: + self.compile() + return self._kernel + + def __call__(self, *args, **kwargs): + return self.kernel(*args, **kwargs) + def compile(self, target=None, *args, **kwargs): """Convenience wrapper for pystencils.create_kernel(...).compile() See :func: ~pystencils.create_kernel @@ -157,6 +166,7 @@ class AssignmentCollection(pystencils.AssignmentCollection): else: kernel.__call__ = partial(kernel, **self.kwargs) + self._kernel = kernel return kernel def backward(self): @@ -177,18 +187,18 @@ class AssignmentCollection(pystencils.AssignmentCollection): def _create_ml_op(self, backend, target, **kwargs): if not target: target = 'gpu' - constant_field_names = [f for f, t in kwargs.items() - if hasattr(t, 'requires_grad') and not t.requires_grad] - constant_fields = {f for f in self.free_fields if f.name in constant_field_names} + # constant_field_names = [f for f, t in kwargs.items() + # if hasattr(t, 'requires_grad') and not t.requires_grad] + # constant_fields = {f for f in self.free_fields if f.name in constant_field_names} for n in [f for f, t in kwargs.items() if hasattr(t, 'requires_grad')]: kwargs.pop(n) if not self._autodiff: if hasattr(self, '_create_autodiff'): - self._create_autodiff(constant_fields, **kwargs) + self._create_autodiff(**kwargs) else: - self._autodiff = _create_autodiff(self, constant_fields, **kwargs) + self._autodiff = _create_autodiff(self, **kwargs) op = self._autodiff.create_tensorflow_op(backend=backend, use_cuda=(target == 'gpu'))