COMPASS  5.4.4
End-to-end AO simulation tool using GPU acceleration
modalBasis.py
1 
37 from shesha.ao import basis
38 import shesha.util.utilities as util
39 import shesha.util.make_pupil as mkP
40 import shesha.constants as scons
41 import scipy.ndimage
42 from scipy.sparse import csr_matrix
43 import numpy as np
44 
class ModalBasis(object):
    """ This optimizer class handles all the modal basis and DM influence functions
    related operations.

    Attributes:
        _config : (config) : Configuration parameters module

        _dms : (DmCompass) : DmCompass instance

        _target : (TargetCompass) : TargetCompass instance

        slaved_actus : (np.ndarray of bool) : Mask of the actuators removed from the
                       last merged Btt computation (actuators under the spiders and
                       second actuator of each slaved pair). None until a merged Btt
                       basis has been computed.

        selected_actus : (np.ndarray of bool) : Complement of slaved_actus

        couples_actus : (np.ndarray) : (npairs, 2) array of actuator indices; the
                        actuator in column 1 is slaved to the one in column 0

        index_under_spiders : (list) : Indices of the actuators found under a spider

        btt : (np.ndarray) : Last Btt modes to volts matrix computed by compute_btt_basis

        modal_basis : (np.ndarray) : Last modal basis computed

        projection_matrix : (np.ndarray) : Last projection_matrix computed
    """

    def __init__(self, config, dms, target):
        """ Instantiate a ModalBasis object

        Args:
            config : (config) : Configuration parameters module

            dms : (DmCompass) : DmCompass instance

            target : (TargetCompass) : TargetCompass instance
        """
        self._config = config
        self._dms = dms
        self._target = target
        self.slaved_actus = None
        self.selected_actus = None
        self.couples_actus = None
        self.index_under_spiders = None
        self.btt = None
        self.modal_basis = None
        self.projection_matrix = None

    def compute_influ_basis(self, dm_index: int) -> csr_matrix:
        """ Computes and returns the influence function phase basis of the specified DM
        as a sparse matrix

        Args:
            dm_index : (int) : Index of the DM

        Returns:
            influ_sparse : (csr_matrix) : influence function phases
        """
        return basis.compute_dm_basis(self._dms._dms.d_dms[dm_index],
                                      self._config.p_dms[dm_index],
                                      self._config.p_geom)

    def compute_influ_delta(self, dm_index: int) -> np.ndarray:
        """ Computes and returns the IF delta for the specified DM

        The delta is IF.dot(IF.T) divided by the number of valid pixels of the
        medium pupil (sum of p_geom.get_mpupil()).

        Args:
            dm_index : (int) : Index of the DM

        Returns:
            influ_delta : (np.ndarray) : influence function deltas (dense)
        """
        ifsparse = self.compute_influ_basis(dm_index)
        npix_in_pup = np.sum(self._config.p_geom.get_mpupil())
        ifdelta = ifsparse.dot(ifsparse.T) / npix_in_pup
        return ifdelta.toarray()

    def _normalize_modal_basis_sign(self) -> np.ndarray:
        """ Flips, in place, the sign of each column of self.modal_basis so that the
        first non-zero element of every column is positive.

        Returns:
            sig : (np.ndarray) : sign factor applied to each column
        """
        fnz = util.first_non_zero(self.modal_basis, axis=0)
        # Index with a tuple (row indices, column indices): indexing with a plain
        # list of arrays triggers numpy's FutureWarning.
        sig = np.sign(self.modal_basis[fnz, np.arange(self.modal_basis.shape[1])])
        self.modal_basis *= sig[None, :]
        return sig

    def compute_modes_to_volts_basis(self, modal_basis_type: str, *, merged: bool = False,
                                     nbpairs: int = None,
                                     return_delta: bool = False) -> tuple:
        """ Computes a given modal basis ("KL2V", "Btt", "Btt_petal") and returns the
        2 transfer matrices

        Args:
            modal_basis_type : (str) : modal basis to compute ("KL2V", "Btt", "Btt_petal")

        Kwargs:
            merged : (bool) : "Btt" only, forwarded to compute_btt_basis: merge the
                              influence functions of actuator pairs across the spiders

            nbpairs : (int) : "Btt" only, forwarded to compute_btt_basis

            return_delta : (bool) : "Btt" only, forwarded to compute_btt_basis. If True
                                    the second returned matrix is the IF delta
                                    instead of the projection matrix

        Returns:
            modal_basis : (np.ndarray) : modes to volts matrix

            projection_matrix : (np.ndarray) : volts to modes matrix (None if "KL2V")

        Raises:
            RuntimeError : if modal_basis_type is not supported
        """
        if modal_basis_type == "KL2V":
            print("Computing KL2V basis...")
            self.modal_basis = basis.compute_KL2V(
                    self._config.p_controllers[0], self._dms._dms, self._config.p_dms,
                    self._config.p_geom, self._config.p_atmos, self._config.p_tel)
            self._normalize_modal_basis_sign()
            # No volts-to-modes matrix is computed for KL2V: reset the attribute so a
            # projection matrix left over from a previous call is not returned.
            self.projection_matrix = None
        elif modal_basis_type == "Btt":
            print("Computing Btt basis...")
            self.modal_basis, self.projection_matrix = self.compute_btt_basis(
                    merged=merged, nbpairs=nbpairs, return_delta=return_delta)
            sig = self._normalize_modal_basis_sign()
            # Keep P.dot(Btt) == Id: the rows of the projection matrix must follow
            # the sign flips applied to the columns of the modal basis. Skipped when
            # the second matrix is the IF delta rather than a projection matrix.
            if not return_delta and self.projection_matrix is not None:
                self.projection_matrix *= sig[:, None]
        elif modal_basis_type == "Btt_petal":
            print("Computing Btt with a Petal basis...")
            self.modal_basis, self.projection_matrix = self.compute_btt_petal()
        else:
            raise RuntimeError("Unsupported modal basis")

        return self.modal_basis, self.projection_matrix

    def compute_btt_basis(self, *, merged: bool = False, nbpairs: int = None,
                          return_delta: bool = False) -> tuple:
        """ Computes the so-called Btt modal basis. The <merged> flag allows to merge
        2x2 the actuators influence functions for actuators on each side of the spider (ELT case)

        Kwargs:
            merged : (bool) : If True, merge 2x2 the actuators influence functions for
                              actuators on each side of the spider (ELT case). Default
                              is False

            nbpairs : (int) : Default is None. Forwarded to compute_merged_influ to
                              limit the number of pairs searched around the middle of
                              each spider piece

            return_delta : (bool) : If False (default), the function returns
                                    Btt (modes to volts matrix),
                                    and P (volts to mode matrix).
                                    If True, returns delta = IF.T.dot(IF) / N
                                    instead of P

        Returns:
            Btt : (np.ndarray) : Btt modes to volts matrix

            projection_matrix : (np.ndarray) : volts to Btt modes matrix
        """
        dms_basis = basis.compute_IFsparse(self._dms._dms, self._config.p_dms,
                                           self._config.p_geom)
        # The last two rows are handled separately as the tip-tilt influence functions
        influ_basis = dms_basis[:-2, :]
        tt_basis = dms_basis[-2:, :].toarray()

        if merged:
            couples_actus, index_under_spiders = self.compute_merged_influ(
                    0, nbpairs=nbpairs)
            influ_basis2 = influ_basis.copy()
            # Actuators under the spiders and slave actuators are dropped from the fit
            index_remove = list(index_under_spiders) + list(couples_actus[:, 1])
            print("Pairing Actuators...")
            for master, slave in couples_actus:
                influ_basis2[master, :] += influ_basis2[slave, :]
            print("Pairing Done")
            boolarray = np.zeros(influ_basis2.shape[0], dtype=bool)
            boolarray[index_remove] = True
            self.slaved_actus = boolarray
            self.selected_actus = ~boolarray
            self.couples_actus = couples_actus
            self.index_under_spiders = index_under_spiders
            influ_basis = influ_basis2[~boolarray, :]

        self.btt, self.projection_matrix = basis.compute_btt(influ_basis.T, tt_basis.T,
                                                             return_delta=return_delta)

        if merged:
            # Expand back to the full actuator set (+ the 2 TT rows): removed
            # actuators get zero, slave actuators copy their master's values.
            # NOTE(review): assumes the second matrix is (nmodes, nactus); with
            # return_delta=True its shape probably differs — confirm before use.
            keep = np.r_[~boolarray, True, True]
            btt2 = np.zeros((len(boolarray) + 2, self.btt.shape[1]))
            btt2[keep, :] = self.btt
            btt2[couples_actus[:, 1], :] = btt2[couples_actus[:, 0], :]

            P2 = np.zeros((self.btt.shape[1], len(boolarray) + 2))
            P2[:, keep] = self.projection_matrix
            P2[:, couples_actus[:, 1]] = P2[:, couples_actus[:, 0]]
            self.btt = btt2
            self.projection_matrix = P2

        return self.btt, self.projection_matrix

    def compute_merged_influ(self, dm_index: int, *, nbpairs: int = None) -> tuple:
        """ Used to compute merged IF from each side of the spider
        for an ELT case (Petalling Effect)

        A spider mask is built as the difference between a pupil drawn without
        spiders and one drawn with p_tel.t_spiders = 0.51 (the original t_spiders
        value is restored afterwards). For each connected piece of that mask a line
        is fitted through its pixels. Actuators closer than 0.05 pitch to the line
        are discarded. Actuators between 0.05 and 1 pitch from it are sorted along
        the line, and consecutive ones closer than 0.2 pitch along it are paired.

        Args:
            dm_index : (int) : DM index

        Kwargs:
            nbpairs : (int) : Default is None (whole sorted list scanned). If set, only
                              the sorted pairable actuators within nbpairs positions of
                              the middle of the list are scanned (window clamped to the
                              list bounds)

        Returns:
            pairs : (np.ndarray) : (npairs, 2) integer array of actuator indices; the
                                   second actuator of each row is meant to be slaved
                                   to the first

            discard : (list) : indices of the actuators located under the spiders
        """
        p_geom = self._config.p_geom
        p_tel = self._config.p_tel
        cent = p_geom.pupdiam / 2. + 0.5

        # make_pupil reads the spider thickness from p_tel: override it temporarily
        # and always restore the user's value.
        t_spiders_orig = p_tel.t_spiders
        try:
            p_tel.t_spiders = 0.51
            spup = mkP.make_pupil(p_geom.pupdiam, p_geom.pupdiam, p_tel, cent,
                                  cent).astype(np.float32).T
            p_tel.t_spiders = 0.
            spup2 = mkP.make_pupil(p_geom.pupdiam, p_geom.pupdiam, p_tel, cent,
                                   cent).astype(np.float32).T
        finally:
            p_tel.t_spiders = t_spiders_orig

        spiders = spup2 - spup

        spiders_id, n_pieces = scipy.ndimage.label(spiders)
        spidersi = util.pad_array(spiders_id, p_geom.ssize).astype(np.float32)
        px_list_spider = [np.where(spidersi == i) for i in range(1, n_pieces + 1)]

        # DM positions in iPupil:
        p_dm = self._config.p_dms[dm_index]
        dm_posx = p_dm._xpos - 0.5
        dm_posy = p_dm._ypos - 0.5
        dm_pos_mat = np.c_[dm_posx, dm_posy].T  # one actu per column

        pitch = p_dm._pitch
        threshold = 0.05
        discard = np.zeros(len(dm_posx), dtype=bool)
        pairs = []

        # For each piece of the spider
        for px_list in px_list_spider:
            pts = np.c_[px_list[1], px_list[0]]  # x,y coord of pixels of the spider piece
            # line_eq = [aa, bb] which minimizes least squares of aa*x + bb*y = 1
            line_eq = np.linalg.pinv(pts).dot(np.ones(pts.shape[0]))
            aa, bb = line_eq[0], line_eq[1]

            # Any point of the fitted line: its intercept with x = 0 or y = 0
            if np.abs(bb) < np.abs(aa):  # near vertical
                one_point = np.array([1 / aa, 0.])
            else:
                one_point = np.array([0., 1 / bb])

            # Rotation that aligns the spider piece with the horizontal axis:
            # row 0 gives the coordinate along the line, row 1 the distance to it
            rotation = np.array([[-bb, aa], [-aa, -bb]]) / (aa**2 + bb**2)**.5

            # Extent of the spider piece along its axis (+/- 5 pitches), used to keep
            # actuators on 'this' side of the pupil and not the other side
            rotated_px = rotation.dot(pts.T - one_point[:, None])
            min_u = rotated_px[0].min() - 5. * pitch
            max_u = rotated_px[0].max() + 5. * pitch

            rotated_actus = rotation.dot(dm_pos_mat - one_point[:, None])
            sel_good_side = (rotated_actus[0] > min_u) & (rotated_actus[0] < max_u)
            dist_to_spider = np.abs(rotated_actus[1])

            # Actuators below this piece of spider
            discard |= (dist_to_spider < threshold * pitch) & sel_good_side

            # Actuators 'near' this piece of spider
            sel_pairable = (dist_to_spider > threshold * pitch) & \
                           (dist_to_spider < 1. * pitch) & sel_good_side
            pairable_index = np.where(sel_pairable)[0]
            u_coord = rotated_actus[0, sel_pairable]  # coord along the spider axis

            sort_idx = np.argsort(u_coord)
            order = u_coord[sort_idx]  # sorted linear coordinates
            order_index = pairable_index[sort_idx]  # matching original indices

            n_sorted = len(order)
            if nbpairs is None:
                i, i_end = 0, n_sorted - 1
            else:
                # Clamp the window: a negative start would silently wrap around and
                # an end past n_sorted - 1 would read out of bounds at order[i + 1]
                i = max(n_sorted // 2 - nbpairs, 0)
                i_end = min(n_sorted // 2 + nbpairs, n_sorted - 1)
            while i < i_end:
                # Pair with the next actuator if very close along the spider axis;
                # some lonely actuators may be hanging in this list
                if np.abs(order[i] - order[i + 1]) < .2 * pitch:
                    pairs.append((order_index[i], order_index[i + 1]))
                    i += 2
                else:
                    i += 1

        print('To discard: %u actu' % np.sum(discard))
        print('%u pairs to slave' % len(pairs))
        # Always return a 2D array so that callers can index pairs[:, 0] / pairs[:, 1]
        # even when no pair was found
        return np.asarray(pairs, dtype=np.intp).reshape(-1, 2), list(np.where(discard)[0])

    def compute_btt_petal(self) -> tuple:
        """ Computes a Btt modal basis with Pistons filtered

        Uses the first PZT DM, the first DM with a PETAL influence type and the
        first TT DM found in p_dms.

        Returns:
            Btt : (np.ndarray) : Btt modes to volts matrix

            P : (np.ndarray) : volts to Btt modes matrix

        Raises:
            IndexError : if one of the three required DM kinds is missing
        """
        p_dms = self._config.p_dms
        pzt_index = np.where([d.type is scons.DmType.PZT for d in p_dms])[0][0]
        influ_pzt = self.compute_influ_basis(pzt_index)
        petal_dm_index = np.where([
                d.influ_type is scons.InfluType.PETAL for d in p_dms
        ])[0][0]
        influ_petal = self.compute_influ_basis(petal_dm_index)
        tt_index = np.where([d.type is scons.DmType.TT for d in p_dms])[0][0]
        influ_tt = self.compute_influ_basis(tt_index).toarray()

        self.modal_basis, self.projection_matrix = basis.compute_btt(
                influ_pzt.T, influ_tt.T, influ_petal=influ_petal)
        return self.modal_basis, self.projection_matrix

    def compute_phase_to_modes(self, modal_basis: np.ndarray) -> np.ndarray:
        """ Returns the phase to modes matrix by using the given modal basis

        Each mode (column of modal_basis) is applied to the DMs, ray-traced on
        target 0 and the resulting pupil phase is normalised to unit rms over the
        small pupil. Side effect: the DMs are left with the last mode's command.

        Args:
            modal_basis : (np.ndarray) : Modal basis matrix (one mode per column)

        Returns:
            phase_to_modes : (np.ndarray) : (nbmode, nx, ny) normalised phases
        """
        nbmode = modal_basis.shape[1]
        phase = self._target.get_tar_phase(0)
        phase_to_modes = np.zeros((nbmode, phase.shape[0], phase.shape[1]))
        S = np.sum(self._config.p_geom._spupil)
        for i in range(nbmode):
            self._dms.set_command((modal_basis[:, i]).copy())
            self._target.raytrace(0, dms=self._dms, ncpa=False, reset=True)
            phase = self._target.get_tar_phase(0, pupil=True)
            # Normalisation to rms units (microns)
            norm = np.sqrt(np.sum(phase**2) / S)
            if norm == 0:
                norm = 1
            phase_to_modes[i] = phase / norm
        return phase_to_modes
This optimizer class handles all operations related to modal bases and DM influence functions.
Definition: modalBasis.py:50
np.ndarray compute_merged_influ(self, int dm_index, *int nbpairs=None)
Used to compute merged IF from each side of the spider for an ELT case (Petalling Effect)
Definition: modalBasis.py:272
np.ndarray compute_phase_to_modes(self, np.ndarray modal_basis)
Return the phase to modes matrix by using the given modal basis.
Definition: modalBasis.py:400
np.ndarray compute_influ_delta(self, int dm_index)
Computes and returns the IF delta for the specified DM.
Definition: modalBasis.py:136
modal_basis
(np.ndarray) : Last modal basis computed
Definition: modalBasis.py:111
csr_matrix compute_influ_basis(self, int dm_index)
Computes and returns the influence function phase basis of the specified DM as a sparse matrix.
Definition: modalBasis.py:123
projection_matrix
(np.ndarray) : Last projection_matrix computed
Definition: modalBasis.py:112
def __init__(self, config, dms, target)
Definition: modalBasis.py:103
np.ndarray compute_btt_petal(self)
Computes a Btt modal basis with Pistons filtered.
Definition: modalBasis.py:379
np.ndarray compute_btt_basis(self, *bool merged=False, int nbpairs=None, bool return_delta=False)
Computes the so-called Btt modal basis.
Definition: modalBasis.py:218
np.ndarray compute_modes_to_volts_basis(self, str modal_basis_type, *bool merged=False, int nbpairs=None, bool return_delta=False)
Computes a given modal basis ("KL2V", "Btt", "Btt_petal") and return the 2 transfer matrices.
Definition: modalBasis.py:161
Python package for AO operations on COMPASS simulation.
Definition: ao/__init__.py:1
Numerical constants for shesha and config enumerations for safe-typing.
Definition: constants.py:1
Pupil creation functions.
Definition: make_pupil.py:1
Basic utilities function.
Definition: utilities.py:1