Skip to content
Snippets Groups Projects
Commit b1dc16e3 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Un-prefix: functions from pystencils_autodiff._file_io

parent bd5b5af0
Branches
Tags
No related merge requests found
...@@ -10,15 +10,15 @@ ...@@ -10,15 +10,15 @@
import jinja2 import jinja2
def _read_template_from_file(file): def read_template_from_file(file):
return jinja2.Template(_read_file(file)) return jinja2.Template(read_file(file))
def _read_file(file): def read_file(file):
with open(file, 'r') as f: with open(file, 'r') as f:
return f.read() return f.read()
def _write_file(filename, content): def write_file(filename, content):
with open(filename, 'w') as f: with open(filename, 'w') as f:
return f.write(content) return f.write(content)
...@@ -22,12 +22,12 @@ except ImportError: ...@@ -22,12 +22,12 @@ except ImportError:
pass pass
def _read_file(file): def read_file(file):
with open(file, 'r') as f: with open(file, 'r') as f:
return f.read() return f.read()
def _write_file(filename, content): def write_file(filename, content):
with open(filename, 'w') as f: with open(filename, 'w') as f:
return f.write(content) return f.write(content)
...@@ -106,22 +106,22 @@ def generate_torch(destination_folder, ...@@ -106,22 +106,22 @@ def generate_torch(destination_folder,
} }
if is_cuda: if is_cuda:
template_string_cpp = _read_file(join(dirname(__file__), template_string_cpp = read_file(join(dirname(__file__),
'torch_native_cuda.tmpl.cpp')) 'torch_native_cuda.tmpl.cpp'))
template = jinja2.Template(template_string_cpp) template = jinja2.Template(template_string_cpp)
output = template.render(render_dict) output = template.render(render_dict)
_write_file(join(destination_folder, operation_string + '.cpp'), output) write_file(join(destination_folder, operation_string + '.cpp'), output)
template_string = _read_file(join(dirname(__file__), 'torch_native_cuda.tmpl.cu')) template_string = read_file(join(dirname(__file__), 'torch_native_cuda.tmpl.cu'))
template = jinja2.Template(template_string) template = jinja2.Template(template_string)
output = template.render(render_dict) output = template.render(render_dict)
_write_file(join(destination_folder, operation_string + '.cu'), output) write_file(join(destination_folder, operation_string + '.cu'), output)
else: else:
template_string_cpp = _read_file(join(dirname(__file__), template_string_cpp = read_file(join(dirname(__file__),
'torch_native_cpu.tmpl.cpp')) 'torch_native_cpu.tmpl.cpp'))
template = jinja2.Template(template_string_cpp) template = jinja2.Template(template_string_cpp)
output = template.render(render_dict) output = template.render(render_dict)
_write_file(join(destination_folder, operation_string + '.cpp'), output) write_file(join(destination_folder, operation_string + '.cpp'), output)
from torch.utils.cpp_extension import load from torch.utils.cpp_extension import load
compiled_operation = load(operation_string, required_files, verbose=True, compiled_operation = load(operation_string, required_files, verbose=True,
......
...@@ -12,7 +12,7 @@ from collections.abc import Iterable ...@@ -12,7 +12,7 @@ from collections.abc import Iterable
from os.path import dirname, join from os.path import dirname, join
from pystencils.astnodes import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol from pystencils.astnodes import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
from pystencils_autodiff._file_io import _read_template_from_file from pystencils_autodiff._file_io import read_template_from_file
from pystencils_autodiff.backends.python_bindings import ( from pystencils_autodiff.backends.python_bindings import (
PybindFunctionWrapping, PybindPythonBindings, TensorflowFunctionWrapping, PybindFunctionWrapping, PybindPythonBindings, TensorflowFunctionWrapping,
TensorflowPythonBindings, TorchPythonBindings) TensorflowPythonBindings, TorchPythonBindings)
...@@ -61,7 +61,7 @@ class PybindArrayDestructuring(DestructuringBindingsForFieldClass): ...@@ -61,7 +61,7 @@ class PybindArrayDestructuring(DestructuringBindingsForFieldClass):
class TorchModule(JinjaCppFile): class TorchModule(JinjaCppFile):
TEMPLATE = _read_template_from_file(join(dirname(__file__), 'module.tmpl.cpp')) TEMPLATE = read_template_from_file(join(dirname(__file__), 'module.tmpl.cpp'))
DESTRUCTURING_CLASS = TorchTensorDestructuring DESTRUCTURING_CLASS = TorchTensorDestructuring
PYTHON_BINDINGS_CLASS = TorchPythonBindings PYTHON_BINDINGS_CLASS = TorchPythonBindings
PYTHON_FUNCTION_WRAPPING_CLASS = PybindFunctionWrapping PYTHON_FUNCTION_WRAPPING_CLASS = PybindFunctionWrapping
...@@ -103,7 +103,7 @@ class TensorflowModule(TorchModule): ...@@ -103,7 +103,7 @@ class TensorflowModule(TorchModule):
:param backward_kernel_ast: :param backward_kernel_ast:
""" """
if use_cuda: if use_cuda:
self.TEMPLATE = _read_template_from_file(join(dirname(__file__), 'tensorflow.cuda.tmpl.cu')) self.TEMPLATE = read_template_from_file(join(dirname(__file__), 'tensorflow.cuda.tmpl.cu'))
super().__init__(module_name, kernel_asts) super().__init__(module_name, kernel_asts)
......
...@@ -13,7 +13,7 @@ import tempfile ...@@ -13,7 +13,7 @@ import tempfile
import pystencils import pystencils
from pystencils_autodiff import create_backward_assignments from pystencils_autodiff import create_backward_assignments
from pystencils_autodiff._file_io import _write_file from pystencils_autodiff._file_io import write_file
from pystencils_autodiff.backends.astnodes import TorchModule from pystencils_autodiff.backends.astnodes import TorchModule
torch = pytest.importorskip('torch') torch = pytest.importorskip('torch')
...@@ -66,7 +66,7 @@ def test_torch_native_compilation_cpu(): ...@@ -66,7 +66,7 @@ def test_torch_native_compilation_cpu():
temp_file = tempfile.NamedTemporaryFile(suffix='.cu' if target == 'gpu' else '.cpp') temp_file = tempfile.NamedTemporaryFile(suffix='.cu' if target == 'gpu' else '.cpp')
print(temp_file.name) print(temp_file.name)
_write_file(temp_file.name, str(module)) write_file(temp_file.name, str(module))
torch_extension = load(module_name, [temp_file.name]) torch_extension = load(module_name, [temp_file.name])
assert torch_extension is not None assert torch_extension is not None
assert 'call_forward' in dir(torch_extension) assert 'call_forward' in dir(torch_extension)
...@@ -99,7 +99,7 @@ def test_torch_native_compilation_gpu(): ...@@ -99,7 +99,7 @@ def test_torch_native_compilation_gpu():
temp_file = tempfile.NamedTemporaryFile(suffix='.cu' if target == 'gpu' else '.cpp') temp_file = tempfile.NamedTemporaryFile(suffix='.cu' if target == 'gpu' else '.cpp')
print(temp_file.name) print(temp_file.name)
_write_file(temp_file.name, str(module)) write_file(temp_file.name, str(module))
torch_extension = load(module_name, [temp_file.name]) torch_extension = load(module_name, [temp_file.name])
assert torch_extension is not None assert torch_extension is not None
assert 'call_forward' in dir(torch_extension) assert 'call_forward' in dir(torch_extension)
......
...@@ -20,7 +20,7 @@ import sympy ...@@ -20,7 +20,7 @@ import sympy
import pystencils import pystencils
from pystencils.include import get_pystencils_include_path from pystencils.include import get_pystencils_include_path
from pystencils_autodiff import create_backward_assignments from pystencils_autodiff import create_backward_assignments
from pystencils_autodiff._file_io import _write_file from pystencils_autodiff._file_io import write_file
from pystencils_autodiff.backends.astnodes import TensorflowModule from pystencils_autodiff.backends.astnodes import TensorflowModule
...@@ -71,8 +71,8 @@ def test_native_tensorflow_compilation_cpu(): ...@@ -71,8 +71,8 @@ def test_native_tensorflow_compilation_cpu():
temp_file = tempfile.NamedTemporaryFile(suffix='.cu' if target == 'gpu' else '.cpp') temp_file = tempfile.NamedTemporaryFile(suffix='.cu' if target == 'gpu' else '.cpp')
print(temp_file.name) print(temp_file.name)
_write_file(temp_file.name, str(module)) write_file(temp_file.name, str(module))
_write_file('/tmp/foo.cpp', str(module)) write_file('/tmp/foo.cpp', str(module))
command = ['c++', '-fPIC', temp_file.name, '-O2', '-shared', command = ['c++', '-fPIC', temp_file.name, '-O2', '-shared',
'-o', 'foo.so'] + compile_flags + link_flags + extra_flags '-o', 'foo.so'] + compile_flags + link_flags + extra_flags
...@@ -115,7 +115,7 @@ def test_native_tensorflow_compilation_gpu(): ...@@ -115,7 +115,7 @@ def test_native_tensorflow_compilation_gpu():
temp_file = tempfile.NamedTemporaryFile(suffix='.cu' if target == 'gpu' else '.cpp') temp_file = tempfile.NamedTemporaryFile(suffix='.cu' if target == 'gpu' else '.cpp')
print(temp_file.name) print(temp_file.name)
_write_file(temp_file.name, str(module)) write_file(temp_file.name, str(module))
# on my machine g++-6 and clang-7 are working # on my machine g++-6 and clang-7 are working
command = ['nvcc', command = ['nvcc',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment