__init__.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. import logging
  2. import threading
  3. import numpy as np
  4. import time
  5. def _pad_array(array):
  6. height, width = array.shape
  7. # Miriam uses floor hence the padding is actually a cutting. Will wait for
  8. # response if this is desired ...
  9. pwidth = 2**np.floor(np.log2(width))
  10. padded = np.zeros((height, pwidth))
  11. padded[:, :width] = array[:, :pwidth]
  12. return padded
  13. try:
  14. import reikna.cluda
  15. import reikna.fft
  16. _plans = {}
  17. _in_buffers = {}
  18. _out_buffers = {}
  19. _api = reikna.cluda.ocl_api()
  20. _thr = _api.Thread.create()
  21. def _fft(array):
  22. start = time.time()
  23. padded = _pad_array(array).astype(np.complex64)
  24. height, width = padded.shape
  25. if width in _plans:
  26. fft = _plans[width]
  27. in_dev = _in_buffers[width]
  28. out_dev = _out_buffers[width]
  29. else:
  30. fft = reikna.fft.FFT(padded, axes=(1,)).compile(_thr)
  31. in_dev = _thr.to_device(padded)
  32. out_dev = _thr.empty_like(in_dev)
  33. _plans[width] = fft
  34. _in_buffers[width] = in_dev
  35. _out_buffers[width] = out_dev
  36. fft(out_dev, in_dev)
  37. logging.info("GPU fft: {} s".format(time.time() - start))
  38. return out_dev.get()[:, :width / 2 + 1]
  39. except ImportError:
  40. logging.info("Failed to import reikna package. Falling back to Numpy FFT.")
  41. def _fft(array):
  42. start = time.time()
  43. freqs = np.fft.rfft(_pad_array(array))
  44. logging.info("np fft: {} s".format(time.time() - start))
  45. return freqs
  46. BUNCHES_PER_TURN = 184
  47. HEADER_SIZE_BYTES = 32
  48. class DataSet(object):
  49. def __init__(self, array, filename):
  50. self.filename = filename
  51. self.array = array
  52. self._heatmaps = {}
  53. self._ffts = {}
  54. def bunch(self, number):
  55. return self.array[self.array[:, 4] == number]
  56. def num_bunches(self):
  57. return self.array.shape[0]
  58. def num_turns(self):
  59. return self.num_bunches() / BUNCHES_PER_TURN
  60. def heatmap(self, adc=1, frm=0, to=-1, bunch_frm=0, bunch_to=-1):
  61. if not 1 <= adc <= 4:
  62. raise ValueError('adc must be in [1,4]')
  63. if not adc in self._heatmaps:
  64. heatmap = self.array[:,adc-1].reshape(-1, BUNCHES_PER_TURN).transpose()
  65. self._heatmaps[adc] = heatmap
  66. return self._heatmaps[adc][bunch_frm:bunch_to, frm:to]
  67. def fft(self, adc=1, frm=0, to=-1):
  68. if not 1 <= adc <= 4:
  69. raise ValueError('adc must be in [1,4]')
  70. # if not adc in self._ffts:
  71. # heatmap = self.heatmap(adc, frm, to)
  72. # self._ffts[adc] = np.fft.fft2(heatmap, axes=[1])
  73. # return self._ffts[adc]
  74. heatmap = self.heatmap(adc, frm, to)
  75. return _fft(heatmap)