COMPASS  5.0.0
End-to-end AO simulation tool using GPU acceleration
closed_loop_mpi.py
import os

import cProfile
import pstats as ps

import sys
import time

import numpy as np
import carmaWrap as ch
import shesha as ao

# Environment variables through which common MPI launchers expose the global
# rank of the process before MPI itself has been initialised. The rank is
# needed here, before mpi4py.MPI is imported (see below), to pick a GPU.
_MPI_RANK_ENV_VARS = (
    "OMPI_COMM_WORLD_RANK",  # Open MPI
    "PMI_RANK",  # MPICH / Intel MPI
    "MV2_COMM_WORLD_RANK",  # MVAPICH2
    "SLURM_PROCID",  # srun
)


def _launcher_rank(environ=None):
    """Return the rank advertised by the MPI launcher, or 0 if none is set.

    :param environ: mapping to read the variables from (defaults to os.environ).
    :return: (int) the first rank found in ``_MPI_RANK_ENV_VARS``, else 0 so
        that a plain single-process run does not fail with a KeyError.
    """
    if environ is None:
        environ = os.environ
    for var in _MPI_RANK_ENV_VARS:
        value = environ.get(var)
        if value is not None:
            return int(value)
    return 0


rank = _launcher_rank()
c = ch.carmaWrap_context()
# Spread the processes over the available devices, modulo the device count.
c.set_active_device(rank % c.get_ndevice())

# Delay import because of cuda_aware:
# mpi_init is called during the import of mpi4py.MPI, so the import is kept
# after the CUDA device selection above (presumably so that CUDA-aware MPI
# binds to the selected GPU).
import mpi4py
from mpi4py import MPI
import hdf5_util as h5u

comm = MPI.COMM_WORLD
comm_size = comm.Get_size()
# From here on, use the rank reported by MPI rather than the launcher's one.
rank = comm.Get_rank()
25 
print("TEST SHESHA\n closed loop with MPI")

if len(sys.argv) != 2:
    error = 'command line should be:"python test.py parameters_filename"\n with "parameters_filename" the path to the parameters file'
    raise Exception(error)

# get parameters from file
param_file = sys.argv[1]
# Compare the extension against str literals: on Python 3 a str never
# equals a bytes literal such as b"py", so such a test is always False.
_param_ext = os.path.splitext(param_file)[1]
if _param_ext == ".py":
    import importlib

    # Make the parameter file importable as a module, then undo the
    # sys.path change even if the import fails.
    param_path = os.path.dirname(param_file)
    sys.path.insert(0, param_path)
    try:
        config = importlib.import_module(
                os.path.splitext(os.path.basename(param_file))[0])
    finally:
        sys.path.remove(param_path)
elif _param_ext == ".h5":
    # Start from a reference parameter module and overwrite it with the
    # values stored in the HDF5 file.
    _par4bench_path = os.environ["SHESHA_ROOT"] + "/data/par/par4bench/"
    sys.path.insert(0, _par4bench_path)
    try:
        import scao_sh_16x16_8pix as config
    finally:
        sys.path.remove(_par4bench_path)
    h5u.configFromH5(param_file, config)
else:
    raise ValueError("Parameter file extension must be .py or .h5")

print("param_file is", param_file)

# A missing or None simul_name is normalised to an empty name.
simul_name = getattr(config, "simul_name", None)
if simul_name is None:
    simul_name = ""
print("simul name is", simul_name)

matricesToLoad = {}
# An empty name (str or bytes) means a clean start: nothing is looked up in
# the matrices database and matricesToLoad stays empty.
if not simul_name:
    clean = 1
else:
    clean = 0
    param_dict = h5u.params_dictionary(config)
    matricesToLoad = h5u.checkMatricesDataBase(os.environ["SHESHA_ROOT"] + "/data/",
                                               config, param_dict)
67 
# initialisation:
# Build the simulation objects in dependency order: the WFS init also
# returns the telescope, and the DMs, target and RTC are built from the
# objects created before them.

# wfs
print("->wfs")
wfs, tel = ao.wfs_init(config.p_wfss, config.p_atmos, config.p_tel, config.p_geom,
                       config.p_target, config.p_loop, comm_size, rank, config.p_dms)

# atmos
print("->atmos")
atm = ao.atmos_init(c, config.p_atmos, config.p_tel, config.p_geom, config.p_loop,
                    rank=rank, load=matricesToLoad)

# dm
print("->dm")
dms = ao.dm_init(config.p_dms, config.p_wfss, wfs, config.p_geom, config.p_tel)

# target
print("->target")
tar = ao.target_init(c, tel, config.p_target, config.p_atmos, config.p_geom,
                     config.p_tel, config.p_dms)

# rtc
print("->rtc")
rtc = ao.rtc_init(tel, wfs, config.p_wfss, dms, config.p_dms, config.p_geom,
                  config.p_rtc, config.p_atmos, atm, config.p_tel, config.p_loop,
                  clean=clean, simul_name=simul_name, load=matricesToLoad)

# Only one rank validates the matrices database entries.
if not clean and rank == 0:
    h5u.validDataBase(os.environ["SHESHA_ROOT"] + "/data/", matricesToLoad)

# Barriers keep the summary below from interleaving with other ranks' output.
comm.Barrier()
if rank == 0:
    print("====================")
    print("init done")
    print("====================")
    print("objects initialized on GPU:")
    print("--------------------------------------------------------")
    print(atm)
    print(wfs)
    print(dms)
    print(tar)
    print(rtc)

    print("----------------------------------------------------")
    print("iter# | S.E. SR | L.E. SR | Est. Rem. | framerate")
    print("----------------------------------------------------")
comm.Barrier()
114 
# Placeholder for an averaged image; it is not read or updated elsewhere in
# this file.
mimg = 0.0
118 
119 
def loop(n):
    """Run ``n`` iterations of the MPI-distributed closed loop.

    Each iteration:

    * rank 0 calls ``atm.move_atmos()``, traces every target through the
      atmosphere and the DMs, and calls ``wfs.sensors_trace`` for every WFS;
    * every rank then calls ``wfs.Bcast_dscreen()`` (presumably broadcasting
      the WFS phase screens from rank 0), computes the WFS images
      (``wfs.sensors_compimg``) and calls ``wfs.gather_bincube`` (presumably
      collecting the images back); these are collective calls, so all ranks
      must reach them;
    * rank 0 computes centroids and the command, and applies it to the DMs.

    Every 50 iterations rank 0 prints the short- and long-exposure Strehl
    ratios of target 0, the estimated remaining time and the frame rate,
    matching the table header printed after initialisation. At the end every
    rank prints its total and mean execution time.

    :param n: (int) number of iterations; nothing is iterated for ``n <= 0``.
    """
    nwfs = len(config.p_wfss)
    t0 = time.perf_counter()
    for i in range(n):
        if rank == 0:
            atm.move_atmos()
            for t in range(config.p_target.ntargets):
                tar.atmos_trace(t, atm, tel)
                tar.dmtrace(t, dms)
            for w in range(nwfs):
                wfs.sensors_trace(w, "all", tel, atm, dms)
        wfs.Bcast_dscreen()
        for w in range(nwfs):
            wfs.sensors_compimg(w)
            wfs.gather_bincube(w)
        if rank == 0:
            rtc.docentroids(0)
            rtc.docontrol(0)
            rtc.applycontrol(0, dms)

        if (i + 1) % 50 == 0 and rank == 0:
            elapsed = time.perf_counter() - t0
            # Guard against a zero interval on coarse clocks.
            framerate = (i + 1) / elapsed if elapsed > 0 else float("inf")
            remaining = (n - i - 1) * elapsed / (i + 1)
            strehltmp = tar.get_strehl(0)
            print("%5d" % (i + 1), " %1.5f" % strehltmp[0], " %1.5f" % strehltmp[1],
                  " %7.2f s" % remaining, " %8.2f Hz" % framerate)

    elapsed = time.perf_counter() - t0
    if n > 0:
        rate = n / elapsed if elapsed > 0 else float("inf")
        print(rank, "| loop execution time:", elapsed, " (", n, "iterations), ",
              elapsed / n, "(mean) ", rate, "Hz")
    else:
        # No iteration was run: there is no mean time or rate to report.
        print(rank, "| loop execution time:", elapsed, " (", n, "iterations)")
195 
196 
# Run the closed loop for the number of iterations set in the parameters.
loop(n=config.p_loop.niter)
closed_loop_mpi.loop
def loop(n)
Definition: closed_loop_mpi.py:120