Commit 9bf7ec1c authored by Martin Uecker's avatar Martin Uecker

model-based reconstruction for look-locker experiments

parent f7671445
......@@ -169,6 +169,7 @@ MODULES_pics = -lgrecon -lsense -liter -llinops -lwavelet -llowrank -lnoncart
MODULES_sqpics = -lsense -liter -llinops -lwavelet -llowrank -lnoncart
MODULES_pocsense = -lsense -liter -llinops -lwavelet
MODULES_nlinv = -lnoir -liter -lnlops -llinops
MODULES_moba = -lmoba -lnoir -liter -lnlops -llinops -lwavelet
MODULES_bpsense = -lsense -lnoncart -liter -llinops -lwavelet
MODULES_itsense = -liter -llinops
MODULES_ecalib = -lcalib
......@@ -184,7 +185,7 @@ MODULES_rof = -liter -llinops
MODULES_tgv = -liter -llinops
MODULES_bench = -lwavelet -llinops
MODULES_phantom = -lsimu
MODULES_bart = -lbox -lgrecon -lsense -lnoir -liter -llinops -lwavelet -llowrank -lnoncart -lcalib -lsimu -lsake -ldfwavelet -lnlops
MODULES_bart = -lbox -lgrecon -lsense -lnoir -liter -llinops -lwavelet -llowrank -lnoncart -lcalib -lsimu -lsake -ldfwavelet -lnlops -lmoba
MODULES_sake = -lsake
MODULES_traj = -lnoncart
MODULES_wave = -liter -lwavelet -llinops -llowrank
......
......@@ -3,7 +3,7 @@
TBASE=show slice crop resize join transpose squeeze flatten zeros ones flip circshift extract repmat bitmask reshape version delta copy casorati vec poly index
TFLP=scale invert conj fmac saxpy sdot spow cpyphs creal carg normalize cdf97 pattern nrmse mip avg cabs zexp
TNUM=fft fftmod fftshift noise bench threshold conv rss filter mandelbrot wavelet window var std fftrot
TRECO=pics pocsense sqpics itsense nlinv nufft rof tgv sake wave lrmatrix estdims estshift estdelay wavepsf wshfl
TRECO=pics pocsense sqpics itsense nlinv moba nufft rof tgv sake wave lrmatrix estdims estshift estdelay wavepsf wshfl
TCALIB=ecalib ecaltwo caldir walsh cc ccapply calmat svd estvar whiten
TMRI=homodyne poisson twixread fakeksp
TSIM=phantom traj
......
mobasrcs := $(wildcard $(srcdir)/moba/*.c)
mobaobjs := $(mobasrcs:.c=.o)
.INTERMEDIATE: $(mobaobjs)
lib/libmoba.a: libmoba.a($(mobaobjs))
UTARGETS += test_moba
MODULES_test_moba += -lmoba -lnlops -llinops
/* Copyright 2013. The Regents of the University of California.
* Copyright 2019. Uecker Lab, University Medical Center Goettingen.
* All rights reserved. Use of this source code is governed by
* a BSD-style license which can be found in the LICENSE file.
*
* Authors: Xiaoqing Wang, Martin Uecker
*/
#include <stdbool.h>
#include <complex.h>
#include <math.h>
#include "num/multind.h"
#include "num/flpmath.h"
#include "num/fft.h"
#include "num/init.h"
#include "misc/mri.h"
#include "misc/misc.h"
#include "misc/mmio.h"
#include "misc/utils.h"
#include "misc/opts.h"
#include "misc/debug.h"
#include "noir/recon.h"
#include "moba/recon_T1.h"
static const char usage_str[] = "<kspace> <TI/TE> <output> [<sensitivities>]";
static const char help_str[] = "Model-based nonlinear inverse reconstruction\n";
int main_moba(int argc, char* argv[])
{
double start_time = timestamp();
float restrict_fov = -1.;
const char* psf = NULL;
struct noir_conf_s conf = noir_defaults;
bool out_sens = false;
bool usegpu = false;
bool unused = false;
enum mdb_t { MDB_T1 } mode = { MDB_T1 };
const struct opt_s opts[] = {
OPT_SELECT('L', enum mdb_t, &mode, MDB_T1, "T1 mapping using model-based look-locker"),
OPT_UINT('i', &conf.iter, "iter", "Number of Newton steps"),
OPT_FLOAT('R', &conf.redu, "", "(reduction factor)"),
OPT_FLOAT('j', &conf.alpha_min, "", "Minimum regu. parameter"),
OPT_INT('d', &debug_level, "level", "Debug level"),
OPT_SET('N', &unused, "(normalize)"), // no-op
OPT_FLOAT('f', &restrict_fov, "FOV", ""),
OPT_STRING('p', &psf, "PSF", ""),
OPT_SET('g', &usegpu, "use gpu"),
};
cmdline(&argc, argv, 2, 4, usage_str, help_str, ARRAY_SIZE(opts), opts);
if (5 == argc)
out_sens = true;
num_init();
long ksp_dims[DIMS];
complex float* kspace_data = load_cfl(argv[1], DIMS, ksp_dims);
long TI_dims[DIMS];
complex float* TI = load_cfl(argv[2], DIMS, TI_dims);
assert(TI_dims[TE_DIM] == ksp_dims[TE_DIM]);
assert(1 == ksp_dims[MAPS_DIM]);
long dims[DIMS];
md_copy_dims(DIMS, dims, ksp_dims);
long img_dims[DIMS];
md_select_dims(DIMS, FFT_FLAGS|MAPS_FLAG|COEFF_FLAG|SLICE_FLAG|TIME2_FLAG, img_dims, dims);
img_dims[COEFF_DIM] = 3;
long img_strs[DIMS];
md_calc_strides(DIMS, img_strs, img_dims, CFL_SIZE);
long single_map_dims[DIMS];
md_select_dims(DIMS, FFT_FLAGS|MAPS_FLAG|SLICE_FLAG|TIME2_FLAG, single_map_dims, dims);
long single_map_strs[DIMS];
md_calc_strides(DIMS, single_map_strs, single_map_dims, CFL_SIZE);
long coil_dims[DIMS];
md_select_dims(DIMS, FFT_FLAGS|COIL_FLAG|MAPS_FLAG|SLICE_FLAG|TIME2_FLAG, coil_dims, dims);
long coil_strs[DIMS];
md_calc_strides(DIMS, coil_strs, coil_dims, CFL_SIZE);
complex float* img = create_cfl(argv[3], DIMS, img_dims);
complex float* single_map = create_cfl("", DIMS, single_map_dims);
long msk_dims[DIMS];
md_select_dims(DIMS, FFT_FLAGS, msk_dims, dims);
long msk_strs[DIMS];
md_calc_strides(DIMS, msk_strs, msk_dims, CFL_SIZE);
complex float* mask = NULL;
complex float* norm = md_alloc(DIMS, img_dims, CFL_SIZE);
complex float* sens = (out_sens ? create_cfl : anon_cfl)(out_sens ? argv[4] : "", DIMS, coil_dims);
md_zfill(DIMS, img_dims, img, 1.0);
md_clear(DIMS, coil_dims, sens, CFL_SIZE);
complex float* pattern = NULL;
long pat_dims[DIMS];
if (NULL != psf) {
complex float* tmp_psf =load_cfl(psf, DIMS, pat_dims);
pattern = anon_cfl("", DIMS, pat_dims);
md_copy(DIMS, pat_dims, pattern, tmp_psf, CFL_SIZE);
unmap_cfl(DIMS, pat_dims, tmp_psf);
// FIXME: check compatibility
if (-1 == restrict_fov)
restrict_fov = 0.5;
conf.noncart = true;
} else {
md_copy_dims(DIMS, pat_dims, img_dims);
pattern = anon_cfl("", DIMS, pat_dims);
estimate_pattern(DIMS, ksp_dims, COIL_FLAG, pattern, kspace_data);
}
double scaling = 5000. / md_znorm(DIMS, ksp_dims, kspace_data);
double scaling_psf = 1000. / md_znorm(DIMS, pat_dims, pattern);
debug_printf(DP_INFO, "Scaling: %f\n", scaling);
md_zsmul(DIMS, ksp_dims, kspace_data, kspace_data, scaling);
debug_printf(DP_INFO, "Scaling_psf: %f\n", scaling_psf);
md_zsmul(DIMS, pat_dims, pattern, pattern, scaling_psf);
if (-1. == restrict_fov) {
mask = md_alloc(DIMS, msk_dims, CFL_SIZE);
md_zfill(DIMS, msk_dims, mask, 1.);
} else {
float restrict_dims[DIMS] = { [0 ... DIMS - 1] = 1. };
restrict_dims[0] = restrict_fov;
restrict_dims[1] = restrict_fov;
restrict_dims[2] = restrict_fov;
mask = compute_mask(DIMS, msk_dims, restrict_dims);
//md_zsmul2(DIMS, img_dims, img_strs, img, msk_strs, mask ,1.0);
md_zmul2(DIMS, img_dims, img_strs, img, img_strs, img, msk_strs, mask);
// Choose a different initial guess for R1*
long pos[DIMS];
for (int i = 0; i < (int)DIMS; i++)
pos[i] = 0;
pos[COEFF_DIM] = 2;
md_copy_block(DIMS, pos, single_map_dims, single_map, img_dims, img, CFL_SIZE);
md_zsmul2(DIMS, single_map_dims, single_map_strs, single_map, single_map_strs, single_map, 1.5);
md_copy_block(DIMS, pos, img_dims, img, single_map_dims, single_map, CFL_SIZE);
}
// conf.alpha = 0.1;
#ifdef USE_CUDA
if (usegpu) {
complex float* kspace_gpu = md_alloc_gpu(DIMS, ksp_dims, CFL_SIZE);
md_copy(DIMS, ksp_dims, kspace_gpu, kspace_data, CFL_SIZE);
complex float* TI_gpu = md_alloc_gpu(DIMS, TI_dims, CFL_SIZE);
md_copy(DIMS, TI_dims, TI_gpu, TI, CFL_SIZE);
switch (mode) {
case MDB_T1:
T1_recon(&conf, dims, img, sens, pattern, mask, TI_gpu, kspace_gpu, usegpu);
break;
};
md_free(kspace_gpu);
md_free(TI_gpu);
} else
#endif
switch (mode) {
case MDB_T1:
T1_recon(&conf, dims, img, sens, pattern, mask, TI, kspace_data, usegpu);
break;
};
md_free(norm);
md_free(mask);
unmap_cfl(DIMS, coil_dims, sens);
unmap_cfl(DIMS, pat_dims, pattern);
unmap_cfl(DIMS, img_dims, img);
unmap_cfl(DIMS, single_map_dims, single_map);
unmap_cfl(DIMS, ksp_dims, kspace_data);
unmap_cfl(DIMS, TI_dims, TI);
double recosecs = timestamp() - start_time;
debug_printf(DP_DEBUG2, "Total Time: %.2f s\n", recosecs);
exit(0);
}
/* Copyright 2019. Uecker Lab, University Medical Center Goettingen.
* All rights reserved. Use of this source code is governed by
* a BSD-style license which can be found in the LICENSE file.
*
* Authors: Xiaoqing Wang, Martin Uecker
*/
#include <complex.h>
#include "misc/types.h"
#include "misc/misc.h"
#include "misc/mri.h"
#include "misc/debug.h"
#include "num/multind.h"
#include "num/flpmath.h"
#include "nlops/nlop.h"
#include "T1fun.h"
//#define general
//#define mphase
struct T1_s {
INTERFACE(nlop_data_t);
int N;
const long* map_dims;
const long* TI_dims;
const long* in_dims;
const long* out_dims;
const long* map_strs;
const long* TI_strs;
const long* in_strs;
const long* out_strs;
// Parameter maps
complex float* Mss;
complex float* M0;
complex float* R1s;
complex float* tmp_map;
complex float* tmp_ones;
complex float* tmp_exp;
complex float* tmp_dMss;
complex float* tmp_dM0;
complex float* tmp_dR1s;
complex float* TI;
float scaling_M0;
float scaling_R1s;
};
DEF_TYPEID(T1_s);
// Calculate Model: Mss - (Mss + M0) * exp(-t.*R1s)
static void T1_fun(const nlop_data_t* _data, complex float* dst, const complex float* src)
{
struct T1_s* data = CAST_DOWN(T1_s, _data);
long pos[data->N];
for (int i = 0; i < data->N; i++)
pos[i] = 0;
// Mss
pos[COEFF_DIM] = 0;
md_copy_block(data->N, pos, data->map_dims, data->Mss, data->in_dims, src, CFL_SIZE);
// M0
pos[COEFF_DIM] = 1;
md_copy_block(data->N, pos, data->map_dims, data->M0, data->in_dims, src, CFL_SIZE);
// R1s
pos[COEFF_DIM] = 2;
md_copy_block(data->N, pos, data->map_dims, data->R1s, data->in_dims, src, CFL_SIZE);
// -1*scaling_R1s.*R1s
md_zsmul2(data->N, data->map_dims, data->map_strs, data->tmp_map, data->map_strs, data->R1s, -1.0*data->scaling_R1s);
// exp(-t.*scaling_R1s*R1s):
//#ifdef general
// md_zmul2(data->N, data->out_dims, data->out_strs, data->tmp_exp, data->map_strs, data->tmp_map, data->TI_strs, data->TI);
#ifdef mphase
long map_no_time2_dims[DIMS];
md_select_dims(DIMS, ~TIME2_FLAG, map_no_time2_dims, data->map_dims);
long map_no_time2_strs[DIMS];
md_calc_strides(DIMS, map_no_time2_strs, map_no_time2_dims, CFL_SIZE);
for (int w = 0; w < (data->TI_dims[11]); w++)
for(int k = 0; k < (data->TI_dims[5]); k++) {
debug_printf(DP_DEBUG2, "\tTI: %f\n", creal(data->TI[k + data->TI_dims[5]*w]));
md_zsmul2(data->N, map_no_time2_dims, map_no_time2_strs, (void*)data->tmp_exp + data->out_strs[5] * k + data->out_strs[11] * w,
map_no_time2_strs, (void*)data->tmp_map + data->map_strs[11] * w, data->TI[k + data->TI_dims[5]*w]);
}
#else
for(int k=0; k < (data->TI_dims[5]); k++)
md_zsmul2(data->N, data->map_dims, data->out_strs, (void*)data->tmp_exp + data->out_strs[5] * k, data->map_strs, (void*)data->tmp_map, data->TI[k]);
#endif
long img_dims[data->N];
md_select_dims(data->N, FFT_FLAGS, img_dims, data->map_dims);
md_zexp(data->N, data->out_dims, data->tmp_exp, data->tmp_exp);
// scaling_M0.*M0
md_zsmul2(data->N, data->map_dims, data->map_strs, data->tmp_map, data->map_strs, data->M0, data->scaling_M0);
// Mss + scaling_M0*M0
md_zadd(data->N, data->map_dims, data->tmp_map, data->Mss, data->tmp_map);
// (Mss + scaling_M0*M0).*exp(-t.*scaling_R1s*R1s)
md_zmul2(data->N, data->out_dims, data->out_strs, dst, data->map_strs, data->tmp_map, data->out_strs, data->tmp_exp);
// Mss -(Mss + scaling_M0*M0).*exp(-t.*scaling_R1s*R1s)
md_zsub2(data->N, data->out_dims, data->out_strs, dst, data->map_strs, data->Mss, data->out_strs, dst);
// Calculating derivatives
// M0' = -scaling_M0.*exp(-t.*scaling_R1s.*R1s)
md_zsmul(data->N, data->out_dims, data->tmp_dM0, data->tmp_exp, -data->scaling_M0);
// Mss' = 1 - exp(-t.*scaling_R1s.*R1s)
md_zfill(data->N, data->map_dims, data->tmp_ones, 1.0);
md_zsub2(data->N, data->out_dims, data->out_strs, data->tmp_dMss, data->map_strs, data->tmp_ones, data->out_strs, data->tmp_exp);
// t*exp(-t.*scaling_R1s*R1s):
//#ifdef general
//md_zmul2(data->N, data->out_dims, data->out_strs, data->tmp_exp, data->out_strs, data->tmp_exp, data->TI_strs, data->TI);
#ifdef mphase
for (int s=0; s < data->out_dims[13]; s++)
for (int w=0; w < data->TI_dims[11]; w++)
for(int k=0; k < data->TI_dims[5]; k++)
md_zsmul(data->N, img_dims, (void*)data->tmp_exp + data->out_strs[5] * k + data->out_strs[11] * w + data->out_strs[13] * s,
(void*)data->tmp_exp + data->out_strs[5] * k + data->out_strs[11] * w + data->out_strs[13] * s, data->TI[k + data->TI_dims[5] * w]);
#else
for (int s=0; s < data->out_dims[13]; s++)
for(int k=0; k < data->TI_dims[5]; k++)
//debug_printf(DP_DEBUG2, "\tTI: %f\n", creal(data->TI[k]));
md_zsmul(data->N, img_dims, (void*)data->tmp_exp + data->out_strs[5] * k + data->out_strs[13] * s,
(void*)data->tmp_exp + data->out_strs[5] * k + data->out_strs[13] * s, data->TI[k]);
#endif
// scaling_M0:*exp(-t.*scaling_R1s.*R1s).*t
md_zsmul(data->N, data->out_dims, data->tmp_exp, data->tmp_exp, data->scaling_M0);
// scaling_M0.*M0
md_zsmul2(data->N, data->map_dims, data->map_strs, data->tmp_map, data->map_strs, data->M0, data->scaling_M0);
// Mss + scaling_M0*M0
md_zadd(data->N, data->map_dims, data->tmp_ones, data->Mss, data->tmp_map);
// R1s' = (Mss + scaling_M0*M0) * scaling_M0.*exp(-t.*scaling_R1s.*R1s) * t
md_zmul2(data->N, data->out_dims, data->out_strs, data->tmp_dR1s, data->map_strs, data->tmp_ones, data->out_strs, data->tmp_exp);
}
static void T1_der(const nlop_data_t* _data, complex float* dst, const complex float* src)
{
struct T1_s* data = CAST_DOWN(T1_s, _data);
long pos[data->N];
for (int i = 0; i < data->N; i++)
pos[i] = 0;
// tmp = dM0
pos[COEFF_DIM] = 1;
md_copy_block(data->N, pos, data->map_dims, data->tmp_map, data->in_dims, src, CFL_SIZE);
//const complex float* tmp_M0 = (const void*)src + md_calc_offset(data->N, data->in_strs, pos);
// dst = M0' * dM0
md_zmul2(data->N, data->out_dims, data->out_strs, dst, data->map_strs, data->tmp_map, data->out_strs, data->tmp_dM0);
// tmp = dMss
pos[COEFF_DIM] = 0;
md_copy_block(data->N, pos, data->map_dims, data->tmp_map, data->in_dims, src, CFL_SIZE);
//const complex float* tmp_Mss = (const void*)src + md_calc_offset(data->N, data->in_strs, pos);
// dst = dst + dMss * Mss'
md_zfmac2(data->N, data->out_dims, data->out_strs, dst, data->map_strs, data->tmp_map, data->out_strs, data->tmp_dMss);
// tmp = dR1s
pos[COEFF_DIM] = 2;
md_copy_block(data->N, pos, data->map_dims, data->tmp_map, data->in_dims, src, CFL_SIZE);
//const complex float* tmp_R1s = (const void*)src + md_calc_offset(data->N, data->in_strs, pos);
// dst = dst + dR1s * R1s'
md_zfmac2(data->N, data->out_dims, data->out_strs, dst, data->map_strs, data->tmp_map, data->out_strs, data->tmp_dR1s);
}
static void T1_adj(const nlop_data_t* _data, complex float* dst, const complex float* src)
{
struct T1_s* data = CAST_DOWN(T1_s, _data);
long pos[data->N];
for (int i = 0; i < data->N; i++)
pos[i] = 0;
// sum (conj(M0') * src, t)
md_clear(data->N, data->map_dims, data->tmp_map, CFL_SIZE);
md_zfmacc2(data->N, data->out_dims, data->map_strs, data->tmp_map, data->out_strs, src, data->out_strs, data->tmp_dM0);
// dst[1] = sum (conj(M0') * src, t)
pos[COEFF_DIM] = 1;
md_copy_block(data->N, pos, data->in_dims, dst, data->map_dims, data->tmp_map, CFL_SIZE);
// sum (conj(Mss') * src, t)
md_clear(data->N, data->map_dims, data->tmp_map, CFL_SIZE);
md_zfmacc2(data->N, data->out_dims, data->map_strs, data->tmp_map, data->out_strs, src, data->out_strs, data->tmp_dMss);
// dst[0] = sum (conj(Mss') * src, t)
pos[COEFF_DIM] = 0;
md_copy_block(data->N, pos, data->in_dims, dst, data->map_dims, data->tmp_map, CFL_SIZE);
// sum (conj(R1s') * src, t)
md_clear(data->N, data->map_dims, data->tmp_map, CFL_SIZE);
md_zfmacc2(data->N, data->out_dims, data->map_strs, data->tmp_map, data->out_strs, src, data->out_strs, data->tmp_dR1s);
//md_zreal(data->N, data->map_dims, data->tmp_map, data->tmp_map);
// dst[2] = sum (conj(R1s') * src, t)
pos[COEFF_DIM] = 2;
md_copy_block(data->N, pos, data->in_dims, dst, data->map_dims, data->tmp_map, CFL_SIZE);
}
static void T1_del(const nlop_data_t* _data)
{
struct T1_s* data = CAST_DOWN(T1_s, _data);
md_free(data->Mss);
md_free(data->M0);
md_free(data->R1s);
md_free(data->TI);
md_free(data->tmp_map);
md_free(data->tmp_ones);
md_free(data->tmp_exp);
md_free(data->tmp_dM0);
md_free(data->tmp_dMss);
md_free(data->tmp_dR1s);
xfree(data->map_dims);
xfree(data->TI_dims);
xfree(data->in_dims);
xfree(data->out_dims);
xfree(data->map_strs);
xfree(data->TI_strs);
xfree(data->in_strs);
xfree(data->out_strs);
xfree(data);
}
struct nlop_s* nlop_T1_create(int N, const long map_dims[N], const long out_dims[N], const long in_dims[N], const long TI_dims[N], const complex float* TI, bool use_gpu)
{
#ifdef USE_CUDA
md_alloc_fun_t my_alloc = use_gpu ? md_alloc_gpu : md_alloc;
#else
assert(!use_gpu);
md_alloc_fun_t my_alloc = md_alloc;
#endif
PTR_ALLOC(struct T1_s, data);
SET_TYPEID(T1_s, data);
PTR_ALLOC(long[N], ndims);
md_copy_dims(N, *ndims, map_dims);
data->map_dims = *PTR_PASS(ndims);
PTR_ALLOC(long[N], nodims);
md_copy_dims(N, *nodims, out_dims);
data->out_dims = *PTR_PASS(nodims);
PTR_ALLOC(long[N], nidims);
md_copy_dims(N, *nidims, in_dims);
data->in_dims = *PTR_PASS(nidims);
PTR_ALLOC(long[N], ntidims);
md_copy_dims(N, *ntidims, TI_dims);
data->TI_dims = *PTR_PASS(ntidims);
PTR_ALLOC(long[N], nmstr);
md_calc_strides(N, *nmstr, map_dims, CFL_SIZE);
data->map_strs = *PTR_PASS(nmstr);
PTR_ALLOC(long[N], nostr);
md_calc_strides(N, *nostr, out_dims, CFL_SIZE);
data->out_strs = *PTR_PASS(nostr);
PTR_ALLOC(long[N], nistr);
md_calc_strides(N, *nistr, in_dims, CFL_SIZE);
data->in_strs = *PTR_PASS(nistr);
PTR_ALLOC(long[N], ntistr);
md_calc_strides(N, *ntistr, TI_dims, CFL_SIZE);
data->TI_strs = *PTR_PASS(ntistr);
data->N = N;
data->Mss = my_alloc(N, map_dims, CFL_SIZE);
data->M0 = my_alloc(N, map_dims, CFL_SIZE);
data->R1s = my_alloc(N, map_dims, CFL_SIZE);
data->tmp_map = my_alloc(N, map_dims, CFL_SIZE);
data->tmp_ones = my_alloc(N, map_dims, CFL_SIZE);
data->tmp_exp = my_alloc(N, out_dims, CFL_SIZE);
data->tmp_dM0 = my_alloc(N, out_dims, CFL_SIZE);
data->tmp_dMss = my_alloc(N, out_dims, CFL_SIZE);
data->tmp_dR1s = my_alloc(N, out_dims, CFL_SIZE);
#ifdef general
data->TI = my_alloc(N, TI_dims, CFL_SIZE);
#else
data->TI = md_alloc(N, TI_dims, CFL_SIZE);
#endif
md_copy(N, TI_dims, data->TI, TI, CFL_SIZE);
data->scaling_M0 = 2.0;
data->scaling_R1s = 1.0;
return nlop_create(N, out_dims, N, in_dims, CAST_UP(PTR_PASS(data)), T1_fun, T1_der, T1_adj, NULL, NULL, T1_del);
}
struct nlop_s;
struct noir_model_conf_s;
extern struct nlop_s* nlop_T1_create(int N, const long map_dims[N], const long out_dims[N], const long in_dims[N],
const long TI_dims[N], const complex float* TI, bool use_gpu);
/* Copyright 2013-2014. The Regents of the University of California.
* Copyright 2019. Uecker Lab, University Medical Center Goettingen.
* All rights reserved. Use of this source code is governed by
* a BSD-style license which can be found in the LICENSE file.
*
* Authors: Xiaoqing Wang, Martin Uecker
*/
#include <assert.h>
#include <stdbool.h>
#include <math.h>
#include <stdio.h>
#include "misc/types.h"
#include "misc/mri.h"
#include "misc/debug.h"
#include "misc/misc.h"
#include "num/multind.h"
#include "num/flpmath.h"
#include "num/ops_p.h"
#include "num/rand.h"
#include "num/ops.h"
#include "num/iovec.h"
#include "wavelet/wavthresh.h"
#include "nlops/nlop.h"
#include "iter/prox.h"
#include "iter/vec.h"
#include "iter/italgos.h"
#include "iter/iter3.h"
#include "iter/iter2.h"
#include "iter_l1.h"
struct T1inv_s {
INTERFACE(iter_op_data);
const struct nlop_s* nlop;
const struct iter3_irgnm_conf* conf;
long size_x;
long size_y;
float alpha;
const long* dims;
bool first_iter;
int outer_iter;
const struct operator_p_s* prox1;
const struct operator_p_s* prox2;
};
DEF_TYPEID(T1inv_s);
static void normal_fista(iter_op_data* _data, float* dst, const float* src)
{
auto data = CAST_DOWN(T1inv_s, _data);
float* tmp = md_alloc_sameplace(1, MD_DIMS(data->size_y), FL_SIZE, src);
linop_forward_unchecked(nlop_get_derivative(data->nlop, 0, 0), (complex float*)tmp, (const complex float*)src);
linop_adjoint_unchecked(nlop_get_derivative(data->nlop, 0, 0), (complex float*)dst, (const complex float*)tmp);
md_free(tmp);
long res = data->dims[0];
long parameters = data->dims[COEFF_DIM];
long coils = data->dims[COIL_DIM];
md_axpy(1, MD_DIMS(data->size_x * coils / (coils + parameters)),
dst + res * res * 2 * parameters,
data->alpha,
src + res * res * 2 * parameters);
}
static void pos_value(iter_op_data* _data, float* dst, const float* src)
{
auto data = CAST_DOWN(T1inv_s, _data);
long res = data->dims[0];
long parameters = data->dims[COEFF_DIM];
long dims1[DIMS];
md_select_dims(DIMS, FFT_FLAGS, dims1, data->dims);
md_zsmax(DIMS, dims1, (_Complex float*)dst + (parameters - 1) * res * res,
(const _Complex float*)src + (parameters - 1) * res * res, 0.1);
}
static void combined_prox(iter_op_data* _data, float rho, float* dst, const float* src)
{
struct T1inv_s* data = CAST_DOWN(T1inv_s, _data);
if (data->first_iter) {
data->first_iter = false;
} else {
pos_value(_data, dst, src);
}
operator_p_apply_unchecked(data->prox2, rho, (_Complex float*)dst, (const _Complex float*)dst);