diff --git a/docs/source/tutorials/01_tutorial_getting_started.ipynb b/docs/source/tutorials/01_tutorial_getting_started.ipynb index baa3aac6ac9ad5a42db9244ff03d5f34e246530f..5ce765fcea33088463c5e5274cab8fb5654f6229 100644 --- a/docs/source/tutorials/01_tutorial_getting_started.ipynb +++ b/docs/source/tutorials/01_tutorial_getting_started.ipynb @@ -1140,7 +1140,7 @@ }, { "cell_type": "code", - "execution_count": 72, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1270,8 +1270,8 @@ "source": [ "ast = ps.create_kernel(\n", " update_rule,\n", - " cpu_optim = ps.config.CpuOptimConfig(\n", - " openmp=ps.config.OpenMpConfig(num_threads=2))\n", + " cpu_optim = ps.CpuOptimConfig(\n", + " openmp=ps.OpenMpConfig(num_threads=2))\n", " )\n", "\n", "ps.show_code(ast)" @@ -1472,7 +1472,7 @@ "metadata": { "anaconda-cloud": {}, "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": ".venv", "language": "python", "name": "python3" }, diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py index 0fd49b248fd4e260546bea9b1727f93a4c951463..f7836ea794080a02a26dd8b29bf2d010ff5ab389 100644 --- a/src/pystencils/codegen/driver.py +++ b/src/pystencils/codegen/driver.py @@ -395,7 +395,7 @@ class DefaultKernelCreationDriver: req_headers |= self._platform.required_headers req_headers |= self._ctx.required_headers return req_headers - + def create_cpu_kernel_function( ctx: KernelCreationContext, @@ -410,9 +410,7 @@ def create_cpu_kernel_function( params = _get_function_params(ctx, undef_symbols) req_headers = _get_headers(ctx, platform, body) - kfunc = Kernel( - body, target_spec, function_name, params, req_headers, jit - ) + kfunc = Kernel(body, target_spec, function_name, params, req_headers, jit) kfunc.metadata.update(ctx.metadata) return kfunc @@ -421,14 +419,16 @@ def create_gpu_kernel_function( ctx: KernelCreationContext, platform: Platform, body: PsBlock, - threads_range: GpuThreadsRange, + threads_range: GpuThreadsRange | None, function_name: str, target_spec: Target, jit: JitBase, ): undef_symbols = collect_undefined_symbols(body) - for threads in threads_range.num_work_items: - undef_symbols |= collect_undefined_symbols(threads) + + if threads_range is not None: + for threads in threads_range.num_work_items: + undef_symbols |= collect_undefined_symbols(threads) params = _get_function_params(ctx, undef_symbols) req_headers = _get_headers(ctx, platform, body) diff --git a/src/pystencils/runhelper/db.py b/src/pystencils/runhelper/db.py index dd413a5e405771822d36611d1068936b74ee334c..e199829584c65ea096db1fc6c8e0192e44805705 100644 --- a/src/pystencils/runhelper/db.py +++ b/src/pystencils/runhelper/db.py @@ -8,7 +8,7 @@ import six from blitzdb.backends.file.backend import serializer_classes from blitzdb.backends.file.utils import JsonEncoder -from pystencils.backend.jit.legacy_cpu import get_compiler_config +from pystencils.jit.legacy_cpu import get_compiler_config from pystencils import CreateKernelConfig, Target, Field import json