make_burst_kernels.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. """Generate burst laminographic backprojection OpenCL kernels."""
  2. import argparse
  3. import os
  4. IDX_TO_VEC_ELEM = dict(zip(range(10), range(10)))
  5. IDX_TO_VEC_ELEM[10] = 'a'
  6. IDX_TO_VEC_ELEM[11] = 'b'
  7. IDX_TO_VEC_ELEM[12] = 'c'
  8. IDX_TO_VEC_ELEM[13] = 'd'
  9. IDX_TO_VEC_ELEM[14] = 'e'
  10. IDX_TO_VEC_ELEM[15] = 'f'
  11. def fill_compute_template(tmpl, num_items, index):
  12. """Fill the template doing the pixel computation and texture fetch."""
  13. operation = '+' if index else ''
  14. access = '.s{}'.format(IDX_TO_VEC_ELEM[index]) if num_items > 1 else ''
  15. return tmpl.format(index, access, operation)
  16. def fill_kernel_template(input_tmpl, compute_tmpl, kernel_outer, kernel_inner, num_items):
  17. """Construct the whole kernel."""
  18. vector_length = num_items if num_items > 1 else ''
  19. computes = '\n'.join([fill_compute_template(compute_tmpl, num_items, i)
  20. for i in range(num_items)])
  21. inputs = '\n'.join([input_tmpl.format(i) for i in range(num_items)])
  22. kernel_inner = kernel_inner.format(computes)
  23. return kernel_outer.format(num_items, inputs, vector_length, kernel_inner)
  24. def parse_args():
  25. """Parse command line arguments."""
  26. parser = argparse.ArgumentParser()
  27. parser.add_argument('filename', type=str, help='File name with the kernel template')
  28. parser.add_argument('burst', type=int, nargs='+',
  29. help='Number of projections processed by one kernel invocation')
  30. return parser.parse_args()
  31. def main():
  32. """execute program."""
  33. args = parse_args()
  34. allowed_bursts = [2 ** i for i in range(5)]
  35. in_tmpl = "read_only image2d_t projection_{},"
  36. common_filename = os.path.join(os.path.dirname(args.filename), 'common.in')
  37. defs_filename = os.path.join(os.path.dirname(args.filename), 'definitions.in')
  38. defs = open(defs_filename, 'r').read()
  39. kernel_outer = open(common_filename, 'r').read()
  40. comp_tmpl, kernel_inner = open(args.filename, 'r').read().split('\n%nl\n')
  41. kernels = defs + '\n'
  42. for burst in args.burst:
  43. if burst not in allowed_bursts:
  44. raise ValueError('Burst mode `{}` must be one of `{}`'.format(burst, allowed_bursts))
  45. kernels += fill_kernel_template(in_tmpl, comp_tmpl, kernel_outer, kernel_inner, burst)
  46. print kernels
  47. if __name__ == '__main__':
  48. main()