Commit f32b2342 authored by Martin Uecker's avatar Martin Uecker

reduce memory

parent 03e4559f
......@@ -111,7 +111,8 @@ static complex float* compute_linphases(int N, long lph_dims[N + 1], unsigned lo
static void compute_kern_basis(unsigned int N, const long krn_dims[N], complex float* krn,
static void compute_kern_basis(unsigned int N, unsigned int flags, const long pos[N],
const long krn_dims[N], complex float* krn,
const long bas_dims[N], const complex float* basis,
const long wgh_dims[N], const complex float* weights)
{
......@@ -164,30 +165,38 @@ static void compute_kern_basis(unsigned int N, const long krn_dims[N], complex f
baT_strs[5] = baT_strs[6];
baT_strs[6] = 0;
long krn2_dims[N];
md_copy_dims(N, krn2_dims, krn_dims);
long krn_strs[N];
md_calc_strides(N, krn_strs, krn2_dims, CFL_SIZE);
md_calc_strides(N, krn_strs, krn_dims, CFL_SIZE);
long ma2_dims[N];
md_tenmul_dims(N, ma2_dims, krn2_dims, max_dims, baT_dims);
md_tenmul_dims(N, ma2_dims, krn_dims, max_dims, baT_dims);
long ma3_dims[N];
md_select_dims(N, flags, ma3_dims, ma2_dims);
#if 1
long tmp_off = md_calc_offset(N - 1, max_strs, pos);
long bas_off = md_calc_offset(N - 1, baT_strs, pos);
#endif
md_zsmul(N, max_dims, tmp, tmp, (double)bas_dims[6]); // FIXME: Why?
md_ztenmulc2(N, ma2_dims, krn_strs, krn, max_strs, tmp, baT_strs, basis);
md_ztenmulc2(N, ma3_dims, krn_strs, krn,
max_strs, (void*)tmp + tmp_off,
baT_strs, (void*)basis + bas_off);
md_zsmul(N, krn2_dims, krn, krn, (double)bas_dims[6]); // FIXME: Why?
//md_zsmul(N, krn_dims, krn, krn, (double)bas_dims[6]); // FIXME: Why?
md_free(tmp);
}
static void compute_kern(unsigned int N, const long krn_dims[N], complex float* krn,
static void compute_kern(unsigned int N, unsigned int flags, const long pos[N],
const long krn_dims[N], complex float* krn,
const long bas_dims[N], const complex float* basis,
const long wgh_dims[N], const complex float* weights)
{
if (NULL != basis)
return compute_kern_basis(N, krn_dims, krn, bas_dims, basis, wgh_dims, weights);
return compute_kern_basis(N, flags, pos, krn_dims, krn, bas_dims, basis, wgh_dims, weights);
md_zfill(N, krn_dims, krn, 1.);
......@@ -209,7 +218,7 @@ static void compute_kern(unsigned int N, const long krn_dims[N], complex float*
complex float* compute_psf(unsigned int N, const long img_dims[N], const long trj_dims[N], const complex float* traj,
complex float* compute_psf(int N, const long img_dims[N], const long trj_dims[N], const complex float* traj,
const long bas_dims[N], const complex float* basis,
const long wgh_dims[N], const complex float* weights, bool periodic)
{
......@@ -240,28 +249,34 @@ complex float* compute_psf(unsigned int N, const long img_dims[N], const long tr
assert(1 == trj2_dims[6]);
ksp2_dims[N - 1] = trj2_dims[5];
trj2_dims[N - 1] = trj2_dims[5];
trj2_dims[5] = 1;
trj2_dims[5] = 1; // FIXME copy?
}
struct nufft_conf_s conf = nufft_conf_defaults;
conf.periodic = periodic;
conf.toeplitz = false; // avoid infinite loop
complex float* ones = md_alloc(N, ksp2_dims, CFL_SIZE);
debug_printf(DP_DEBUG2, "nufft kernel dims: ");
debug_print_dims(DP_DEBUG2, N, ksp2_dims);
compute_kern(N, ksp2_dims, ones, bas2_dims, basis, wgh2_dims, weights);
complex float* psft = md_alloc(N, img2_dims, CFL_SIZE);
debug_printf(DP_DEBUG2, "nufft psf dims: ");
debug_print_dims(DP_DEBUG2, N, img2_dims);
debug_printf(DP_DEBUG2, "nufft traj dims: ");
debug_print_dims(DP_DEBUG2, N, trj2_dims);
long pos[N];
for (int i = 0; i < N; i++)
pos[i] = 0;
#if 0
complex float* ones = md_alloc(N, ksp2_dims, CFL_SIZE);
compute_kern(N, ~0u, pos, ksp2_dims, ones, bas2_dims, basis, wgh2_dims, weights);
complex float* psft = md_alloc(N, img2_dims, CFL_SIZE);
struct linop_s* op2 = nufft_create(N, ksp2_dims, img2_dims, trj2_dims, traj, NULL, conf);
......@@ -270,6 +285,32 @@ complex float* compute_psf(unsigned int N, const long img_dims[N], const long tr
linop_free(op2);
md_free(ones);
#else
complex float* psft = md_calloc(N, img2_dims, CFL_SIZE);
long trj2_strs[N];
md_calc_strides(N, trj2_strs, trj2_dims, CFL_SIZE);
complex float* ones = md_alloc(N - 1, ksp2_dims, CFL_SIZE);
complex float* tmp = md_alloc(N - 1, img2_dims, CFL_SIZE);
for (long i = 0; i < trj2_dims[N - 1]; i++) {
pos[N - 1] = i;
compute_kern(N, ~(1 << (N - 1)), pos, ksp2_dims, ones, bas2_dims, basis, wgh2_dims, weights);
struct linop_s* op2 = nufft_create(N - 1, ksp2_dims, img2_dims, trj2_dims, (void*)traj + i * trj2_strs[N - 1], NULL, conf);
linop_adjoint_unchecked(op2, tmp, ones);
md_zadd(N - 1, img2_dims, psft, psft, tmp);
linop_free(op2);
}
md_free(ones);
md_free(tmp);
#endif
return psft;
}
......
......@@ -41,7 +41,7 @@ extern struct linop_s* nufft_create2(unsigned int N,
const complex float* basis,
struct nufft_conf_s conf);
extern _Complex float* compute_psf(unsigned int N,
extern _Complex float* compute_psf(int N,
const long img2_dims[__VLA(N)],
const long trj_dims[__VLA(N)],
const complex float* traj,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment