Explorar el Código

Removed possible usage of reikna fft

This is the fft on GPU package
Miriam Brosi hace 7 años
padre
commit
e2e8478961
Se han modificado 1 ficheros con 40 adiciones y 39 borrados
  1. 40 39
      KCG/base/backend/dataset.py

+ 40 - 39
KCG/base/backend/dataset.py

@@ -17,45 +17,46 @@ def _pad_array(array):
     padded[:, :width] = array[:, :pwidth]
     return padded
 
-try:
-    import reikna.cluda
-    import reikna.fft
-
-    _plans = {}
-    _in_buffers = {}
-    _out_buffers = {}
-
-    _api = reikna.cluda.ocl_api()
-    _thr = _api.Thread.create()
-
-    def _fft(array):
-        start = time.time()
-        padded = _pad_array(array).astype(np.complex64)
-        height, width = padded.shape
-
-        if width in _plans:
-            fft = _plans[width]
-            in_dev = _in_buffers[width]
-            out_dev = _out_buffers[width]
-        else:
-            fft = reikna.fft.FFT(padded, axes=(1,)).compile(_thr)
-            in_dev = _thr.to_device(padded)
-            out_dev = _thr.empty_like(in_dev)
-            _plans[width] = fft
-            _in_buffers[width] = in_dev
-            _out_buffers[width] = out_dev
-
-        fft(out_dev, in_dev)
-        logging.debug("GPU fft: {} s".format(time.time() - start))
-        return out_dev.get()[:, :width / 2 + 1]
-
-except ImportError:
-    logging.debug("Failed to import reikna package. Falling back to Numpy FFT.")
-    def _fft(array):
-        start = time.time()
-        freqs = np.fft.rfft(_pad_array(array))
-        logging.debug("np fft: {} s".format(time.time() - start))
-        return freqs
+#try:
+    #import reikna.cluda
+    #import reikna.fft
+
+    #_plans = {}
+    #_in_buffers = {}
+    #_out_buffers = {}
+
+    #_api = reikna.cluda.ocl_api()
+    #_thr = _api.Thread.create()
+
+    #def _fft(array):
+        #start = time.time()
+        #padded = _pad_array(array).astype(np.complex64)
+        #height, width = padded.shape
+
+        #if width in _plans:
+            #fft = _plans[width]
+            #in_dev = _in_buffers[width]
+            #out_dev = _out_buffers[width]
+        #else:
+            #fft = reikna.fft.FFT(padded, axes=(1,)).compile(_thr)
+            #in_dev = _thr.to_device(padded)
+            #out_dev = _thr.empty_like(in_dev)
+            #_plans[width] = fft
+            #_in_buffers[width] = in_dev
+            #_out_buffers[width] = out_dev
+
+        #fft(out_dev, in_dev)
+        #logging.debug("GPU fft: {} s".format(time.time() - start))
+        #return out_dev.get()[:, :width / 2 + 1]
+    #logging.info("Using GPU based FFT!")
+
+#except ImportError:
+    #logging.debug("Failed to import reikna package. Falling back to Numpy FFT.")
+def _fft(array):
+    start = time.time()
+    freqs = np.fft.rfft(_pad_array(array))
+    logging.debug("np fft: {} s".format(time.time() - start))
+    return freqs
 
 
 BUNCHES_PER_TURN = config.bunches_per_turn