COMPASS  5.4.4
End-to-end AO simulation tool using GPU acceleration
coronagraph_init.py
1 
37 import os
38 import numpy as np
39 import astropy.io.fits as pfits
40 import shesha.config as conf
41 import shesha.constants as scons
42 import shesha.util.coronagraph_utils as util
43 
44 # ---------------------------------- #
45 # Initialization functions #
46 # ---------------------------------- #
47 
48 
49 # TODO : add some checks (dimension, type, etc...)
50 
def init_coronagraph(p_corono: conf.Param_corono, pupdiam):
    """ Initialize the coronagraph

    First sets up the wavelength sampling stored in ``p_corono``, then
    initializes the apodizer, focal plane mask and Lyot stop planes according
    to ``p_corono._type``.

    Args:
        p_corono: (Param_corono) coronagraph parameters, updated in place
        pupdiam: (int) pupil size in pixels, forwarded to the plane initializers
    """
    lambda_0 = p_corono._wavelength_0
    bandwidth = p_corono._delta_wav
    n_lambda = p_corono._nb_wav

    # Wavelength sampling: a null bandwidth or a single sample both collapse
    # to a monochromatic setup at the central wavelength.
    if bandwidth == 0 or n_lambda == 1:
        if bandwidth == 0:
            p_corono.set_nb_wav(1)
        else:
            p_corono.set_delta_wav(0.)
        p_corono.set_wav_vec(np.array([lambda_0]))
    else:
        # n_lambda samples evenly spread over [lambda_0 - bw/2, lambda_0 + bw/2]
        half_band = bandwidth / 2
        p_corono.set_wav_vec(np.linspace(lambda_0 - half_band,
                                         lambda_0 + half_band,
                                         num=n_lambda,
                                         endpoint=True))

    # Planes initialization, dispatched on the coronagraph type
    corono_type = p_corono._type
    if corono_type == scons.CoronoType.SPHERE_APLC:
        init_sphere_aplc(p_corono, pupdiam)
        return
    if corono_type == scons.CoronoType.PERFECT:
        init_perfect_coronagraph(p_corono, pupdiam)
        return
    # Generic coronagraph: each plane is built from the names stored in p_corono
    init_apodizer(p_corono, pupdiam)
    init_focal_plane_mask(p_corono)
    init_lyot_stop(p_corono, pupdiam)
81 
def init_sphere_aplc(p_corono: conf.Param_corono, pupdiam):
    """ Dedicated function for SPHERE APLC coronagraph init

    - image: defaults to a 256 pixel wide image sampled at the IRDIS plate
      scale, when ``_dim_image`` / ``_image_sampling`` are not provided
    - apodizer: always the SPHERE APLC APO1 apodizer
    - focal plane mask: ALC2 unless a mask name was already selected
    - Lyot stop: always the SPHERE APLC Lyot stop

    The image parameters are set first because ``init_focal_plane_mask``
    falls back on ``_image_sampling`` for the fpm sampling of FITS masks.

    Args:
        p_corono: (Param_corono) coronagraph parameters, updated in place
        pupdiam: (int) pupil size in pixels

    References:
        APLC data : SPHERE user manual, appendix A6 'NIR coronagraphs'
            https://www.eso.org/sci/facilities/paranal/instruments/sphere/doc.html
        IRDIS data : ESO instrument description
            https://www.eso.org/sci/facilities/paranal/instruments/sphere/inst.html
    """
    # image init (must precede the fpm init, see docstring)
    if p_corono._dim_image is None:
        p_corono.set_dim_image(256)
    if p_corono._image_sampling is None:
        irdis_plate_scale = 12.25  # [mas / pixel]
        VLT_pupil_diameter = 8  # [m]
        # 1e-6 factor: _wavelength_0 is presumably given in microns
        lambda_over_D = p_corono._wavelength_0 * 1e-6 / VLT_pupil_diameter  # [rad]
        lambda_over_D_in_mas = lambda_over_D * 180 / np.pi * 3600 * 1000  # [mas]
        # image sampling = number of image pixels per lambda/D
        p_corono.set_image_sampling(lambda_over_D_in_mas / irdis_plate_scale)

    # apodizer init
    p_corono.set_apodizer_name(scons.ApodizerType.SPHERE_APLC_APO1)
    init_apodizer(p_corono, pupdiam)

    # fpm init: keep a user-selected mask, default to ALC2
    if p_corono._focal_plane_mask_name is None:
        p_corono.set_focal_plane_mask_name(scons.FpmType.SPHERE_APLC_fpm_ALC2)
    init_focal_plane_mask(p_corono)

    # lyot stop init
    p_corono.set_lyot_stop_name(scons.LyotStopType.SPHERE_APLC_LYOT_STOP)
    init_lyot_stop(p_corono, pupdiam)
113 
def init_perfect_coronagraph(p_corono: conf.Param_corono, pupdiam):
    """ Dedicated function for perfect coronagraph init

    Currently a no-op: neither ``p_corono`` nor ``pupdiam`` is used.
    """
    return None
118 
def init_apodizer(p_corono: conf.Param_corono, pupdiam):
    """ Apodizer init

    Builds the apodizer from ``p_corono._apodizer_name`` and stores it with
    ``p_corono.set_apodizer``:
      - SPHERE APLC APO1 keyword: built by ``util.make_sphere_apodizer(pupdiam)``
      - None: no apodization, a (pupdiam, pupdiam) array of ones
      - any other string: path to a FITS file holding the apodizer

    Args:
        p_corono: (Param_corono) coronagraph parameters, updated in place
        pupdiam: (int) pupil size in pixels

    Raises:
        ValueError: the name is a string that is neither a keyword nor an existing path
        TypeError: the name is neither None nor a string
    """
    apodizer_name = p_corono._apodizer_name
    if apodizer_name == scons.ApodizerType.SPHERE_APLC_APO1:
        apodizer = util.make_sphere_apodizer(pupdiam)
    elif apodizer_name is None:
        apodizer = np.ones((pupdiam, pupdiam))
    elif isinstance(apodizer_name, str):
        if not os.path.exists(apodizer_name):
            error_string = "apodizer keyword (or path) '{}'".format(apodizer_name) \
                + " is not a known keyword (or path)"
            raise ValueError(error_string)
        apodizer = pfits.getdata(apodizer_name)
    else:
        raise TypeError('apodizer name should be a string')
    p_corono.set_apodizer(apodizer)
135 
def init_focal_plane_mask(p_corono: conf.Param_corono):
    """ Focal plane mask init

    Builds one focal plane mask per wavelength and stores the list with
    ``p_corono.set_focal_plane_mask``. Supported values of
    ``p_corono._focal_plane_mask_name``:
      - classical Lyot keyword: mask built by ``util.classical_lyot_fpm`` from
        ``p_corono._lyot_fpm_radius`` (in lambda/D)
      - SPHERE ALC1 / ALC2 / ALC3 keywords: same as classical Lyot, with the
        radius computed from the SPHERE mask diameter (145 / 185 / 240 mas)
        at ``_wavelength_0`` for the 8 m VLT pupil
      - any other string: path to a FITS file, either 2D (same mask for every
        wavelength) or 3D (one slice per wavelength along the last axis)

    For classical Lyot masks the Babinet trick is enabled, the fpm sampling
    (pixels per lambda/D) defaults to 20 and ``dim_fpm`` is sized to just
    contain the mask. For FITS masks ``dim_fpm`` is the first array dimension
    and the fpm sampling defaults to ``_image_sampling``.

    Args:
        p_corono: (Param_corono) coronagraph parameters, updated in place

    Raises:
        ValueError: unknown keyword / missing file, or FITS data with an
            unsupported number of dimensions or too few wavelength slices
        TypeError: the mask name is not a string
    """
    fpm_name = p_corono._focal_plane_mask_name

    if fpm_name == scons.FpmType.CLASSICAL_LYOT:
        classical_lyot = True
    elif fpm_name in (scons.FpmType.SPHERE_APLC_fpm_ALC1,
                      scons.FpmType.SPHERE_APLC_fpm_ALC2,
                      scons.FpmType.SPHERE_APLC_fpm_ALC3):
        classical_lyot = True
        # SPHERE ALC mask diameters are 145, 185 and 240 mas
        if fpm_name == scons.FpmType.SPHERE_APLC_fpm_ALC1:
            fpm_radius_in_mas = 145 / 2  # [mas]
        elif fpm_name == scons.FpmType.SPHERE_APLC_fpm_ALC2:
            fpm_radius_in_mas = 185 / 2  # [mas]
        else:
            fpm_radius_in_mas = 240 / 2  # [mas]
        VLT_pupil_diameter = 8  # [m]
        # 1e-6 factor: _wavelength_0 is presumably given in microns
        lambda_over_D = p_corono._wavelength_0 * 1e-6 / VLT_pupil_diameter  # [rad]
        fpm_radius = fpm_radius_in_mas / (lambda_over_D * 180 / np.pi * 3600 * 1000)  # [lambda/D]
        p_corono.set_lyot_fpm_radius(fpm_radius)
    else:
        classical_lyot = False

    if classical_lyot:
        p_corono.set_babinet_trick(True)
        if p_corono._fpm_sampling is None:
            p_corono.set_fpm_sampling(20.)
        lyot_fpm_radius_in_pix = p_corono._fpm_sampling * p_corono._lyot_fpm_radius
        # smallest even size containing the mask diameter, plus a margin
        p_corono.set_dim_fpm(2 * int(lyot_fpm_radius_in_pix) + 2)
        fpm = util.classical_lyot_fpm(p_corono._lyot_fpm_radius,
                                      p_corono._dim_fpm,
                                      p_corono._fpm_sampling,
                                      p_corono._wav_vec)

    elif isinstance(fpm_name, str):
        if not os.path.exists(fpm_name):
            error_string = "focal plane mask keyword (or path) '{}'".format(fpm_name) \
                + " is not a known keyword (or path)"
            raise ValueError(error_string)
        fpm_array = pfits.getdata(fpm_name)
        nb_wav = p_corono._nb_wav
        # validate the FITS content before touching p_corono
        if fpm_array.ndim not in (2, 3):
            raise ValueError("focal plane mask '{}' must be a 2D or 3D array, "
                             "got shape {}".format(fpm_name, fpm_array.shape))
        if fpm_array.ndim == 3 and fpm_array.shape[2] < nb_wav:
            raise ValueError("focal plane mask '{}' has {} wavelength slices, "
                             "{} expected".format(fpm_name, fpm_array.shape[2], nb_wav))
        p_corono.set_dim_fpm(fpm_array.shape[0])
        if p_corono._fpm_sampling is None:
            p_corono.set_fpm_sampling(p_corono._image_sampling)
        if fpm_array.ndim == 2:
            # achromatic mask: same array for every wavelength
            fpm = [fpm_array] * nb_wav
        else:
            fpm = [fpm_array[:, :, i] for i in range(nb_wav)]

    else:
        raise TypeError('focal plane mask name should be a string')

    p_corono.set_focal_plane_mask(fpm)
187 
def init_lyot_stop(p_corono: conf.Param_corono, pupdiam):
    """ Lyot stop init

    Builds the Lyot stop from ``p_corono._lyot_stop_name`` and stores it with
    ``p_corono.set_lyot_stop``:
      - SPHERE APLC Lyot stop keyword: built by ``util.make_sphere_lyot_stop(pupdiam)``
      - None: no Lyot stop, a (pupdiam, pupdiam) array of ones
      - any other string: path to a FITS file holding the Lyot stop

    Args:
        p_corono: (Param_corono) coronagraph parameters, updated in place
        pupdiam: (int) pupil size in pixels

    Raises:
        ValueError: the name is a string that is neither a keyword nor an existing path
        TypeError: the name is neither None nor a string
    """
    lyot_stop_name = p_corono._lyot_stop_name
    if lyot_stop_name == scons.LyotStopType.SPHERE_APLC_LYOT_STOP:
        lyot_stop = util.make_sphere_lyot_stop(pupdiam)
    elif lyot_stop_name is None:
        lyot_stop = np.ones((pupdiam, pupdiam))
    elif isinstance(lyot_stop_name, str):
        if not os.path.exists(lyot_stop_name):
            error_string = "Lyot stop keyword (or path) '{}'".format(lyot_stop_name) \
                + " is not a known keyword (or path)"
            raise ValueError(error_string)
        lyot_stop = pfits.getdata(lyot_stop_name)
    else:
        raise TypeError('Lyot stop name should be a string')
    p_corono.set_lyot_stop(lyot_stop)
204 
def init_mft(p_corono: conf.Param_corono, pupdiam, planes, center_on_pixel=False):
    """ Initialize mft matrices

    Computes, for every wavelength of ``p_corono._wav_vec``, the MFT matrices
    (see ``mft_matrices``) propagating between two planes of the coronagraph.
    The number of resolution elements is
    ``dim_focal / sampling / (wavelength / wavelength_0)`` where ``dim_focal``
    and ``sampling`` are those of the focal plane involved (fpm or image).

    Args:
        p_corono: (Param_corono) coronagraph parameters (read only)
        pupdiam: (int) pupil-plane array size in pixels
        planes: (str) propagation to initialize, one of
            'apod_to_fpm'   : pupil (pupdiam) -> fpm (dim_fpm), direct MFT
            'fpm_to_lyot'   : fpm (dim_fpm) -> pupil (pupdiam), inverse MFT
            'lyot_to_image' : pupil (pupdiam) -> image (dim_image), direct MFT
        center_on_pixel: (bool) 'lyot_to_image' only; output offset of half a
            pixel so that the optical axis falls on the center of pixel
            (dim_image / 2, dim_image / 2) instead of between pixels

    Returns:
        AA: (np.ndarray, complex64) shape (dim_output, dim_input, nb_wav)
        BB: (np.ndarray, complex64) shape (dim_input, dim_output, nb_wav)
        norm0: (np.ndarray) shape (nb_wav,), MFT normalization factors

    Raises:
        ValueError: unknown ``planes`` value
    """
    dim_fpm = p_corono._dim_fpm
    fpm_sampling = p_corono._fpm_sampling
    dim_image = p_corono._dim_image
    image_sampling = p_corono._image_sampling
    wavelength_0 = p_corono._wavelength_0
    wav_vec = p_corono._wav_vec

    # (input size, output size, focal-plane size, focal-plane sampling, inverse)
    if planes == 'apod_to_fpm':
        dim_in, dim_out, dim_focal, sampling, inverse = \
            pupdiam, dim_fpm, dim_fpm, fpm_sampling, False
    elif planes == 'fpm_to_lyot':
        dim_in, dim_out, dim_focal, sampling, inverse = \
            dim_fpm, pupdiam, dim_fpm, fpm_sampling, True
    elif planes == 'lyot_to_image':
        dim_in, dim_out, dim_focal, sampling, inverse = \
            pupdiam, dim_image, dim_image, image_sampling, False
    else:
        raise ValueError("planes must be 'apod_to_fpm', 'fpm_to_lyot' or "
                         "'lyot_to_image', got {!r}".format(planes))

    # half-pixel output shift, only available for the image plane
    offset = 0.5 if (planes == 'lyot_to_image' and center_on_pixel) else 0

    nb_wav = len(wav_vec)
    AA = np.zeros((dim_out, dim_in, nb_wav), dtype=np.complex64)
    BB = np.zeros((dim_in, dim_out, nb_wav), dtype=np.complex64)
    norm0 = np.zeros(nb_wav)

    for w, wavelength in enumerate(wav_vec):
        wav_ratio = wavelength / wavelength_0
        nbres = dim_focal / sampling / wav_ratio
        AA[:, :, w], BB[:, :, w], norm0[w] = mft_matrices(dim_in,
                                                           dim_out,
                                                           nbres,
                                                           inverse=inverse,
                                                           X_offset_output=offset,
                                                           Y_offset_output=offset)

    return AA, BB, norm0
261 
262 
263 # ---------------------------------------- #
264 # Matrix Fourier Transform (MFT) #
265 # Generic functions from ASTERIX #
266 # ---------------------------------------- #
267 
def _mft_pair(value, valid_types, error_string):
    """ Split an MFT size argument (scalar or 2-tuple of ``valid_types``) into (x, y).

    Raises TypeError(error_string) for any other value.
    """
    if np.isscalar(value):
        if isinstance(value, valid_types):
            return value, value
        raise TypeError(error_string)
    if (isinstance(value, tuple) and len(value) == 2
            and all(isinstance(item, valid_types) for item in value)):
        return value[0], value[1]
    raise TypeError(error_string)


def mft_matrices(dim_input,
                 dim_output,
                 nbres,
                 real_dim_input=None,
                 inverse=False,
                 norm='ortho',
                 X_offset_input=0,
                 Y_offset_input=0,
                 X_offset_output=0,
                 Y_offset_output=0):
    """ Compute the matrices of a Matrix Fourier Transform (MFT)

    The MFT of an input array ``E`` of shape (dim_input_x, dim_input_y) is
    ``norm0 * (AA @ E @ BB)`` (see ``mft_multiplication``), an array of
    shape (dim_output_x, dim_output_y). Input and output coordinates are
    sampled at pixel centers around ``dim / 2 + offset``.

    Args:
        dim_input: (int or tuple of 2 ints) input array size
        dim_output: (int or tuple of 2 ints) output array size
        nbres: (float or tuple of 2 floats) scaling of the output coordinates,
            i.e. the number of resolution elements across the output plane
            (``init_mft`` passes the focal-plane size in lambda/D)
        real_dim_input: (int or tuple of 2 ints, optional) size of the object
            within the input array; nbres is multiplied by
            dim_input / real_dim_input. Defaults to dim_input.
        inverse: (bool) if True use exp(+2i.pi...) instead of exp(-2i.pi...)
        norm: (str) 'backward', 'forward' or 'ortho' normalization
        X_offset_input, Y_offset_input: (float) input center offset [pixels]
        X_offset_output, Y_offset_output: (float) output center offset [pixels]

    Returns:
        AA: (np.ndarray, complex64) shape (dim_output_x, dim_input_x)
        BB: (np.ndarray, complex64) shape (dim_input_y, dim_output_y)
        norm0: (float) normalization factor

    Raises:
        TypeError: a size argument has the wrong type or length
        ValueError: unknown ``norm``
    """
    int_types = (int, np.integer)
    real_types = (float, int, np.floating, np.integer)

    dim_input_x, dim_input_y = _mft_pair(
        dim_input, int_types,
        "'dim_input' must be an int (square input) or tuple of ints of dimension 2")

    if real_dim_input is None:
        real_dim_input = dim_input
    real_dim_input_x, real_dim_input_y = _mft_pair(
        real_dim_input, int_types,
        "'real_dim_input' must be an int (square input pupil) or tuple of ints of dimension 2")

    dim_output_x, dim_output_y = _mft_pair(
        dim_output, int_types,
        "'dim_output' must be an int (square output) or tuple of ints of dimension 2")

    nbresx, nbresy = _mft_pair(
        nbres, real_types,
        "'nbres' must be an float or int (square output) or tuple of float or int of dimension 2")

    # account for an object smaller than the input array
    nbresx = float(nbresx) * dim_input_x / real_dim_input_x
    nbresy = float(nbresy) * dim_input_y / real_dim_input_y

    if norm not in ('backward', 'forward', 'ortho'):
        raise ValueError("norm must be 'backward', 'forward' or 'ortho', got {!r}".format(norm))

    X0 = dim_input_x / 2 + X_offset_input
    Y0 = dim_input_y / 2 + Y_offset_input

    X1 = dim_output_x / 2 + X_offset_output
    Y1 = dim_output_y / 2 + Y_offset_output

    xx0 = ((np.arange(dim_input_x) - X0 + 1 / 2) / dim_input_x)  # Entrance image
    xx1 = ((np.arange(dim_input_y) - Y0 + 1 / 2) / dim_input_y)  # Entrance image
    uu0 = ((np.arange(dim_output_x) - X1 + 1 / 2) / dim_output_x) * nbresx  # Fourier plane
    uu1 = ((np.arange(dim_output_y) - Y1 + 1 / 2) / dim_output_y) * nbresy  # Fourier plane

    # The full normalization goes on the 'forward' direct transform or on the
    # 'backward' inverse transform; 'ortho' splits it evenly (square root).
    if not inverse:
        if norm == 'backward':
            norm0 = 1.
        elif norm == 'forward':
            norm0 = nbresx * nbresy / dim_input_x / dim_input_y / dim_output_x / dim_output_y
        else:
            norm0 = np.sqrt(nbresx * nbresy / dim_input_x / dim_input_y / dim_output_x / dim_output_y)
        sign_exponential = -1
    else:
        if norm == 'backward':
            norm0 = nbresx * nbresy / dim_input_x / dim_input_y / dim_output_x / dim_output_y
        elif norm == 'forward':
            norm0 = 1.
        else:
            norm0 = np.sqrt(nbresx * nbresy / dim_input_x / dim_input_y / dim_output_x / dim_output_y)
        sign_exponential = 1

    AA = np.exp(sign_exponential * 1j * 2 * np.pi * np.outer(uu0, xx0)).astype('complex64')
    BB = np.exp(sign_exponential * 1j * 2 * np.pi * np.outer(xx1, uu1)).astype('complex64')

    return AA, BB, norm0
388 
def mft_multiplication(image, AA, BB, norm):
    """ Computes matrix multiplication for MFT

    Applies the MFT defined by the matrices returned by ``mft_matrices``:
    ``norm * (AA @ image @ BB)``, with ``image`` cast to complex64 first.
    """
    complex_image = image.astype('complex64')
    left_product = AA @ complex_image
    return norm * (left_product @ BB)
Parameter classes for COMPASS.
Numerical constants for shesha and config enumerations for safe-typing.
Definition: constants.py:1
def init_apodizer(conf.Param_corono p_corono, pupdiam)
Apodizer init.
def init_focal_plane_mask(conf.Param_corono p_corono)
Focal plane mask init.
def init_perfect_coronagraph(conf.Param_corono p_corono, pupdiam)
Dedicated function for perfect coronagraph init.
def init_lyot_stop(conf.Param_corono p_corono, pupdiam)
Lyot stop init.
def mft_matrices(dim_input, dim_output, nbres, real_dim_input=None, inverse=False, norm='ortho', X_offset_input=0, Y_offset_input=0, X_offset_output=0, Y_offset_output=0)
Compute the matrices of a Matrix Fourier Transform (MFT).
def init_mft(conf.Param_corono p_corono, pupdiam, planes, center_on_pixel=False)
Initialize mft matrices.
def init_sphere_aplc(conf.Param_corono p_corono, pupdiam)
Dedicated function for SPHERE APLC coronagraph init.
def mft_multiplication(image, AA, BB, norm)
Computes matrix multiplication for MFT.
def init_coronagraph(conf.Param_corono p_corono, pupdiam)
Initialize the coronagraph.