diff --git a/Makefile b/Makefile
index 1bcaa4a8ea131e1b2858efeb3b8dca26f360d9da..60e3a0a1d0a143d0e9cd1734e293f2844b405129 100644
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,6 @@
 .PHONY: all build clean lj_ns
 
-all: build lj_ns clean
+all: build lj_ns
 	@echo "Everything was done!"
 
 build:
@@ -10,8 +10,14 @@ build:
 lj_ns:
 	@echo "Generating and compiling CPP for Lennard-Jones example..."
 	python3 examples/lj_func.py
+
+# Targets
+cpu: build lj_ns
 	g++ -o lj_ns lj_ns.cpp
 
+gpu: build lj_ns
+	nvcc -o lj_ns lj_ns.cu
+
 clean:
 	@echo "Cleaning..."
 	rm -rf build lj_ns lj_ns.cpp dist pairs.egg-info functions functions.pdf
diff --git a/src/pairs/__init__.py b/src/pairs/__init__.py
index 75615c3f6bad7c19758e874342791649fed055d1..d0e878d7d3fc05e7d9e80a22782a53e002a18dcc 100644
--- a/src/pairs/__init__.py
+++ b/src/pairs/__init__.py
@@ -4,7 +4,7 @@ from pairs.sim.simulation import Simulation
 
 
 def simulation(ref, dims=3, timesteps=100, debug=False):
-    return Simulation(CGen(f"{ref}.cpp", debug), dims, timesteps)
+    return Simulation(CGen(ref, debug), dims, timesteps)
 
 def target_cpu():
     return Target(Target.Backend_CPP, Target.Feature_CPU)
diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index f11a8ceb5fb662cf9a607061da1bdaae5fcfce16..e8772a98d6333132e834a803d6277b9049fa6e3c 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -26,11 +26,12 @@ from pairs.code_gen.printer import Printer
 class CGen:
     temp_id = 0
 
-    def __init__(self, output, target, debug=False):
+    def __init__(self, ref, target, debug=False):
         self.sim = None
         self.target = None
+        self.print = None
+        self.ref = ref
         self.debug = debug
-        self.print = Printer(output)
 
     def assign_simulation(self, sim):
         self.sim = sim
@@ -39,6 +40,8 @@ class CGen:
         self.target = target
 
     def generate_program(self, ast_node):
+        ext = ".cu" if self.target.is_gpu() else ".cpp"
+        self.print = Printer(self.ref + ext)
         self.print.start()
         self.print("#include <math.h>")
         self.print("#include <stdbool.h>")