Skip to content
Snippets Groups Projects
Commit 0cccbe5e authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Lint

parent 04ae2726
Branches
Tags
No related merge requests found
Pipeline #17928 failed
...@@ -106,8 +106,7 @@ def generate_torch(destination_folder, ...@@ -106,8 +106,7 @@ def generate_torch(destination_folder,
} }
if is_cuda: if is_cuda:
template_string_cpp = read_file(join(dirname(__file__), template_string_cpp = read_file(join(dirname(__file__), 'torch_native_cuda.tmpl.cpp'))
'torch_native_cuda.tmpl.cpp'))
template = jinja2.Template(template_string_cpp) template = jinja2.Template(template_string_cpp)
output = template.render(render_dict) output = template.render(render_dict)
write_file(join(destination_folder, operation_string + '.cpp'), output) write_file(join(destination_folder, operation_string + '.cpp'), output)
...@@ -117,8 +116,7 @@ def generate_torch(destination_folder, ...@@ -117,8 +116,7 @@ def generate_torch(destination_folder,
output = template.render(render_dict) output = template.render(render_dict)
write_file(join(destination_folder, operation_string + '.cu'), output) write_file(join(destination_folder, operation_string + '.cu'), output)
else: else:
template_string_cpp = read_file(join(dirname(__file__), template_string_cpp = read_file(join(dirname(__file__), 'torch_native_cpu.tmpl.cpp'))
'torch_native_cpu.tmpl.cpp'))
template = jinja2.Template(template_string_cpp) template = jinja2.Template(template_string_cpp)
output = template.render(render_dict) output = template.render(render_dict)
write_file(join(destination_folder, operation_string + '.cpp'), output) write_file(join(destination_folder, operation_string + '.cpp'), output)
......
...@@ -12,7 +12,7 @@ import itertools ...@@ -12,7 +12,7 @@ import itertools
import jinja2 import jinja2
import stringcase import stringcase
from pystencils.astnodes import KernelFunction, Node from pystencils.astnodes import KernelFunction
from pystencils_autodiff.framework_integration.astnodes import JinjaCppFile from pystencils_autodiff.framework_integration.astnodes import JinjaCppFile
......
...@@ -10,7 +10,6 @@ ...@@ -10,7 +10,6 @@
import json import json
import subprocess import subprocess
import sysconfig import sysconfig
from itertools import chain
from os.path import exists, join from os.path import exists, join
from tqdm import tqdm from tqdm import tqdm
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment