COMPASS  5.0.0
End-to-end AO simulation tool using GPU acceleration
array.py
1 
37 
38 import numpy as np
39 from naga.context import Context
40 
41 from carmaWrap import obj_float, obj_double, obj_int, obj_float_complex, obj_double_complex, obj_uint16
try:
    from carmaWrap import obj_half
    USE_HALF = 1
except ImportError:
    # obj_half is optional: when it cannot be imported, half-precision support
    # is disabled. Only a failed import is tolerated here -- the former bare
    # ``except:`` also silenced unrelated errors (and KeyboardInterrupt).
    USE_HALF = 0

# Module-level context whose ``.context`` is passed to every carmaWrap object
# that Array creates from numpy data.
context = Context()
48 
49 
class DimensionsError(Exception):
    """Signal that array operands have incompatible dimensions.

    The payload given at construction is kept on ``value``; ``str()`` of the
    exception is the ``repr`` of that payload.
    """

    def __init__(self, value):
        # Store the raw payload so callers can inspect it directly.
        self.value = value

    def __str__(self):
        return f"{self.value!r}"
57 
58 
class Array:
    """Python-side wrapper around a carmaWrap object (``obj_*``).

    An Array is built from a list/tuple, a numpy array, an existing carmaWrap
    object, or from a ``shape`` (zero-filled). The wrapped carmaWrap object is
    exposed as :attr:`data`; :attr:`shape` and :attr:`dtype` describe it.
    """

    def __init__(self, data=None, shape=None, dtype=None):
        """Create an Array.

        :param data: list, tuple, ``np.ndarray`` or carmaWrap ``obj_*`` instance.
        :param shape: tuple; when given, ``data`` is replaced by zeros of this
            shape.
        :param dtype: numpy dtype that ``data`` is cast to; defaults to
            ``np.float32`` when ``shape`` is given.
        :raises TypeError: ``shape`` is not a tuple, or the data / dtype is not
            supported.
        :raises AttributeError: neither ``data`` nor ``shape`` was provided.
        """
        if shape is not None:
            if dtype is None:
                dtype = np.float32
            if not isinstance(shape, tuple):
                raise TypeError("shape must be a tuple")
            data = np.zeros(shape, dtype=dtype)

        if data is None:
            raise AttributeError("You must provide data or shape at least")

        # Plain Python sequences are converted host-side first.
        if isinstance(data, (list, tuple)):
            data = np.array(data)

        if isinstance(data, np.ndarray):
            if dtype is not None:
                data = data.astype(dtype)
            if data.dtype == np.int64 or data.dtype == np.int32:
                self.__data = obj_int(context.context, data)
            elif data.dtype == np.float32:
                self.__data = obj_float(context.context, data)
            elif data.dtype == np.float64:
                self.__data = obj_double(context.context, data)
            elif USE_HALF and data.dtype == np.float16:
                self.__data = obj_half(context.context, data)
            elif data.dtype == np.complex64:
                self.__data = obj_float_complex(context.context, data)
            elif data.dtype == np.complex128:
                self.__data = obj_double_complex(context.context, data)
            else:
                raise TypeError("Data type not implemented")
            self.__dtype = data.dtype
            self.__shape = data.shape
        else:
            # Wrap an existing carmaWrap object; map its class to a numpy dtype.
            if isinstance(data, obj_int):
                carma_dtype = np.int64
            elif isinstance(data, obj_float):
                carma_dtype = np.float32
            elif isinstance(data, obj_double):
                carma_dtype = np.float64
            elif USE_HALF and isinstance(data, obj_half):
                carma_dtype = np.float16
            elif isinstance(data, obj_float_complex):
                carma_dtype = np.complex64
            elif isinstance(data, obj_double_complex):
                carma_dtype = np.complex128
            elif isinstance(data, obj_uint16):
                carma_dtype = np.uint16
            else:
                raise TypeError("Data must be a list, a numpy array or a carmaWrap.obj")
            self.__data = data
            self.__dtype = carma_dtype
            # Set for every carmaWrap type. The obj_double branch used to
            # assign ``_shape`` instead, so ``.shape`` raised AttributeError.
            self.__shape = tuple(data.shape[k] for k in range(len(data.shape)))
        self.__size = self.__data.nbElem

    shape = property(lambda x: x.__shape)
    dtype = property(lambda x: x.__dtype)
    data = property(lambda x: x.__data)

    def __repr__(self):
        return "GPU" + self.toarray().__repr__()

    def _as_operand(self, idata, op):
        """Return ``idata`` as an Array, or raise TypeError naming operator ``op``."""
        if isinstance(idata, Array):
            return idata
        if isinstance(idata, np.ndarray):
            return Array(idata)
        raise TypeError("operator %s is defined only between Arrays and np.arrays" % op)

    def __add__(self, idata):
        """Return a new Array equal to ``self + idata`` (``axpy`` with alpha=1 on a copy).

        :raises TypeError: ``idata`` is neither an Array nor an ``np.ndarray``.
        """
        # Validate the operand before allocating the result copy.
        other = self._as_operand(idata, "+")
        tmp = self.copy()
        tmp.data.axpy(1, other.data)
        return tmp

    def __sub__(self, idata):
        """Return a new Array equal to ``self - idata`` (``axpy`` with alpha=-1 on a copy).

        :raises TypeError: ``idata`` is neither an Array nor an ``np.ndarray``.
        """
        other = self._as_operand(idata, "-")
        tmp = self.copy()
        tmp.data.axpy(-1, other.data)
        return tmp

    def __mul__(self, idata):
        """Return a copy of self scaled by a Python ``int``/``float`` scalar.

        :raises NotImplementedError: for any other operand type.
        """
        if isinstance(idata, (int, float)):
            tmp = self.copy()
            tmp.data.scale(idata)
            return tmp
        raise NotImplementedError("Operator not implemented yet")

    def __getitem__(self, indices):
        """Index the numpy array returned by :meth:`toarray`."""
        return self.toarray().__getitem__(indices)

    def copy(self):
        """Return a new Array of the same shape/dtype filled via ``copy_from``."""
        tmp = Array(shape=self.shape, dtype=self.dtype)
        tmp.data.copy_from(self.data)
        return tmp

    def dot(self, idata):
        """Matrix/vector product of self with ``idata``.

        1-D . 1-D returns the scalar from ``data.dot``; 1-D . 2-D uses ``gemv``
        with ``op='T'`` on ``idata``; 2-D . 1-D uses ``gemv``; 2-D . 2-D uses
        ``gemm``. Products other than 1-D . 1-D are returned as new Arrays.

        :param idata: Array or ``np.ndarray`` (ndarray dtype must equal self's).
        :raises TypeError: dtype mismatch, or unsupported operand type.
        :raises DimensionsError: an operand is not 1-D/2-D, or sizes mismatch.
        """
        if isinstance(idata, np.ndarray):
            if idata.dtype != self.dtype:
                raise TypeError("Data types must be the same for both arrays")
            idata = Array(idata)
        if not isinstance(idata, Array):
            # Previously fell through and raised UnboundLocalError on ``result``.
            raise TypeError("dot is defined only between Arrays and np.arrays")

        ndim_self = len(self.shape)
        ndim_other = len(idata.shape)
        if ndim_self not in (1, 2) or ndim_other not in (1, 2):
            raise DimensionsError("Arrays must be 1D or 2D max")

        if ndim_self == 1 and ndim_other == 1:
            if idata.shape != self.shape:
                raise DimensionsError("Dimensions mismatch")
            return self.data.dot(idata.data, 1, 1)
        if ndim_self == 1:
            if idata.shape[0] != self.shape[0]:
                raise DimensionsError("Dimensions mismatch")
            return Array(idata.data.gemv(self.data, op='T'))
        if ndim_other == 1:
            if idata.shape[0] != self.shape[1]:
                raise DimensionsError("Dimensions mismatch")
            return Array(self.data.gemv(idata.data))
        if self.shape[1] != idata.shape[0]:
            raise DimensionsError("Dimensions mismatch")
        return Array(self.data.gemm(idata.data))

    def argmax(self):
        """Return the index reported by the carmaWrap ``aimax`` call."""
        return self.data.aimax()

    def max(self):
        """Return ``toarray()[argmax()]``."""
        # NOTE(review): assumes ``aimax`` yields a 0-based index valid on the
        # host array; for 2-D arrays this indexes a row, not an element. BLAS
        # i?amax is 1-based and compares absolute values -- confirm carmaWrap.
        return self.toarray()[self.argmax()]

    def argmin(self):
        """Return the index reported by the carmaWrap ``aimin`` call."""
        return self.data.aimin()

    def min(self):
        """Return ``toarray()[argmin()]`` (same index caveats as :meth:`max`)."""
        return self.toarray()[self.argmin()]

    def sum(self):
        """Return the result of the carmaWrap ``sum`` call."""
        return self.data.sum()

    def toarray(self):
        """Return a numpy array built from the wrapped carmaWrap object."""
        return np.array(self.data)
223 
224 
def ones(shape, dtype=np.float32):
    """Return a new Array of ``shape`` whose elements are all one.

    The buffer is first built host-side with ``np.ones`` and then wrapped.
    """
    host_buffer = np.ones(shape, dtype=dtype)
    return Array(host_buffer)
227 
228 
def zeros(shape, dtype=np.float32):
    """Return a new Array of ``shape`` whose elements are all zero.

    The buffer is first built host-side with ``np.zeros`` and then wrapped.
    """
    host_buffer = np.zeros(shape, dtype=dtype)
    return Array(host_buffer)
naga.array.Array.__shape
__shape
Definition: array.py:91
naga.array.Array.toarray
def toarray(self)
Definition: array.py:220
naga.array.Array.__repr__
def __repr__(self)
Definition: array.py:130
naga.array.Array.dot
def dot(self, idata)
Definition: array.py:169
naga.array.Array.sum
def sum(self)
Definition: array.py:217
naga.array.Array.__dtype
__dtype
Definition: array.py:90
naga.array.Array.argmax
def argmax(self)
Definition: array.py:205
naga.array.Array.__mul__
def __mul__(self, idata)
Definition: array.py:153
naga.array.DimensionsError.__str__
def __str__(self)
Definition: array.py:55
naga.array.Array
Definition: array.py:59
naga.array.DimensionsError.__init__
def __init__(self, value)
Definition: array.py:52
naga.array.Array._shape
_shape
Definition: array.py:103
naga.array.Array.dtype
dtype
Definition: array.py:127
naga.context.Context
Python class for wrapping a CarmaContext.
Definition: context.py:45
naga.array.zeros
def zeros(shape, dtype=np.float32)
Definition: array.py:229
naga.context
Documentation for naga.
Definition: context.py:1
naga.array.Array.__size
__size
Definition: array.py:122
naga.array.Array.shape
shape
Definition: array.py:126
naga.array.DimensionsError.value
value
Definition: array.py:53
naga.array.Array.copy
def copy(self)
Definition: array.py:164
naga.array.Array.__init__
def __init__(self, data=None, shape=None, dtype=None)
Definition: array.py:61
naga.array.Array.min
def min(self)
Definition: array.py:214
naga.array.DimensionsError
Definition: array.py:50
naga.array.Array.__add__
def __add__(self, idata)
Definition: array.py:133
naga.array.Array.argmin
def argmin(self)
Definition: array.py:211
naga.array.Array.__data
__data
Definition: array.py:77
naga.array.Array.__sub__
def __sub__(self, idata)
Definition: array.py:143
naga.array.Array.__getitem__
def __getitem__(self, indices)
Definition: array.py:161
naga.array.Array.max
def max(self)
Definition: array.py:208
naga.array.Array.data
data
Definition: array.py:128
naga.array.ones
def ones(shape, dtype=np.float32)
Definition: array.py:225