# NOTE(review): non-contiguous excerpt of test_double_gemm. The leading integers
# are the original file's line numbers (327-337, 345-365, 372 and 379 are not
# shown), so c, sizem/sizek/sizen, alpha, beta, dec and the NumPy operands
# A, AT, B, BT, C, C2, C3 are defined in lines that are not visible here.
326 def test_double_gemm():
# Checks the double-precision gemm binding on ch.obj_double against a NumPy
# reference. It covers three transpose combinations, each in two forms:
#   * accumulate into an existing output: out <- alpha*op(X)*op(Y) + beta*out
#   * return a new result:                 alpha*op(X)*op(Y)
# Operand shapes: A is sizem x sizek; AT is sizek x sizem (so AT.T is
# sizem x sizek); B is sizek x sizen; BT is sizen x sizek (so BT.T is
# sizek x sizen); every output is sizem x sizen.
# NOTE(review): np.random is not seeded in the visible lines. If it is not
# seeded in the omitted setup either, a failure cannot be reproduced.
338 matA = ch.obj_double(c, np.random.randn(sizem, sizek))
339 matAT = ch.obj_double(c, np.random.randn(sizek, sizem))
340 matB = ch.obj_double(c, np.random.randn(sizek, sizen))
341 matBT = ch.obj_double(c, np.random.randn(sizen, sizek))
342 matC = ch.obj_double(c, np.random.randn(sizem, sizen))
343 matC2 = ch.obj_double(c, np.random.randn(sizem, sizen))
344 matC3 = ch.obj_double(c, np.random.randn(sizem, sizen))
# In-place form: receiver.gemm(other, flag_receiver, flag_other, alpha, out, beta).
# Judging by the NumPy reference further down, the first flag transposes the
# receiver and the second transposes `other` ('n' = as-is, 't' = transposed).
# `out` is overwritten with alpha*op(receiver)*op(other) + beta*out.
366 matA.gemm(matB,
'n',
'n', alpha, matC, beta)
367 matAT.gemm(matB,
't',
'n', alpha, matC2, beta)
368 matAT.gemm(matBT,
't',
't', alpha, matC3, beta)
# Returning form: with no output/beta arguments, gemm returns a new object
# holding alpha*op(receiver)*op(other).
# TODO(review): the ('n', 't') combination (e.g. matA.gemm(matBT, 'n', 't', ...))
# is not exercised by either form.
369 matC4 = matA.gemm(matB,
'n',
'n', alpha)
370 matC5 = matAT.gemm(matB,
't',
'n', alpha)
371 matC6 = matAT.gemm(matBT,
't',
't', alpha)
# NumPy reference results. The accumulate cases assume C, C2 and C3 still hold
# the pre-gemm values of matC, matC2 and matC3. Presumably they are snapshots
# taken in the omitted lines 345-365 — confirm they are copies, not views of
# the mutated objects.
373 C = alpha * A.dot(B) + beta * C
374 C2 = alpha * AT.T.dot(B) + beta * C2
375 C3 = alpha * AT.T.dot(BT.T) + beta * C3
376 C4 = alpha * A.dot(B)
377 C5 = alpha * AT.T.dot(B)
378 C6 = alpha * AT.T.dot(BT.T)
# Compare each binding result (converted with np.array) to its reference.
# The tolerance is derived from `dec`, which is defined outside the visible
# lines — presumably a precision setting shared across tests; confirm.
380 npt.assert_array_almost_equal(C, np.array(matC), decimal=2 * dec - 1)
381 npt.assert_array_almost_equal(C2, np.array(matC2), decimal=2 * dec - 1)
382 npt.assert_array_almost_equal(C3, np.array(matC3), decimal=2 * dec - 1)
383 npt.assert_array_almost_equal(C4, np.array(matC4), decimal=2 * dec - 1)
384 npt.assert_array_almost_equal(C5, np.array(matC5), decimal=2 * dec - 1)
385 npt.assert_array_almost_equal(C6, np.array(matC6), decimal=2 * dec - 1)