Commit 26225a60 authored by Martin Uecker

generalize cuda fft a bit

parent 3e84c221
/* Copyright 2013, 2015. The Regents of the University of California.
* Copyright 2019. Martin Uecker.
* All rights reserved. Use of this source code is governed by
* a BSD-style license which can be found in the LICENSE file.
*
* Authors:
* 2012-2013, 2015 Martin Uecker <uecker@eecs.berkeley.edu>
* 2012-2019 Martin Uecker <martin.uecker@med.uni-goettingen.de>
*
*
* Internal interface to the CUFFT library used in fft.c.
......@@ -30,6 +31,7 @@
struct fft_cuda_plan_s {
cufftHandle cufft;
struct fft_cuda_plan_s* chain;
bool backwards;
......@@ -49,7 +51,7 @@ struct iovec {
struct fft_cuda_plan_s* fft_cuda_plan(unsigned int D, const long dimensions[D], unsigned long flags, const long ostrides[D], const long istrides[D], bool backwards)
static struct fft_cuda_plan_s* fft_cuda_plan0(unsigned int D, const long dimensions[D], unsigned long flags, const long ostrides[D], const long istrides[D], bool backwards)
{
PTR_ALLOC(struct fft_cuda_plan_s, plan);
unsigned int N = D;
......@@ -58,6 +60,7 @@ struct fft_cuda_plan_s* fft_cuda_plan(unsigned int D, const long dimensions[D],
plan->odist = 0;
plan->idist = 0;
plan->backwards = backwards;
plan->chain = NULL;
struct iovec dims[N];
struct iovec hmdims[N];
......@@ -107,8 +110,8 @@ struct fft_cuda_plan_s* fft_cuda_plan(unsigned int D, const long dimensions[D],
for (unsigned int i = 0; i < k; i++) {
assert(dims[i].is == lis);
assert(dims[i].os == los);
// assert(dims[i].is == lis);
// assert(dims[i].os == los);
cudims[k - 1 - i] = dims[i].n;
cuiemb[k - 1 - i] = dims[i].n;
......@@ -183,9 +186,42 @@ errout:
}
/**
 * Create a CUFFT plan for the dimensions selected by flags.
 *
 * A single plan covering all flagged dimensions is tried first via
 * fft_cuda_plan0(). If that fails and more than one dimension is flagged,
 * the transform is split into two stages: a plan for the lowest flagged
 * dimension (istrides -> ostrides), and a chained plan for the remaining
 * flagged dimensions which uses ostrides for both input and output.
 * The remaining dimensions are handled recursively, so the chain may
 * contain more than two plans.
 *
 * @param D          number of dimensions
 * @param dimensions size of each dimension
 * @param flags      bitmask selecting the dimensions to transform
 * @param ostrides   output strides
 * @param istrides   input strides
 * @param backwards  stored in the plan(s); selects the inverse transform
 * @return a plan to be released with fft_cuda_free_plan(), or NULL if
 *         no plan could be created
 */
struct fft_cuda_plan_s* fft_cuda_plan(unsigned int D, const long dimensions[D], unsigned long flags, const long ostrides[D], const long istrides[D], bool backwards)
{
	struct fft_cuda_plan_s* plan = fft_cuda_plan0(D, dimensions, flags, ostrides, istrides, backwards);

	if (NULL != plan)
		return plan;

	// Isolate the lowest set bit as a *mask* (two's complement trick).
	// The previous code used the bit index from ffs() as if it were a
	// mask, which made the test below always false and passed the wrong
	// flags to fft_cuda_plan0().
	unsigned long lsb = flags & (~flags + 1UL);
	unsigned long rest = flags & ~lsb;

	// With zero or one flagged dimension there is nothing to split off,
	// and the direct attempt above has already failed.
	if (0UL == rest)
		return NULL;

	// FIXME: this could be better...
	struct fft_cuda_plan_s* first = fft_cuda_plan0(D, dimensions, lsb, ostrides, istrides, backwards);

	if (NULL == first)
		return NULL;

	// The first stage already writes with ostrides, so the remaining
	// dimensions are transformed from ostrides to ostrides.
	first->chain = fft_cuda_plan(D, dimensions, rest, ostrides, ostrides, backwards);

	if (NULL == first->chain) {

		fft_cuda_free_plan(first);
		return NULL;
	}

	return first;
}
/**
 * Release a plan created by fft_cuda_plan().
 *
 * Any chained plan is released first (recursively), then the CUFFT
 * handle of this plan is destroyed and the plan structure is freed.
 *
 * @param cuplan plan to release; must not be NULL
 */
void fft_cuda_free_plan(struct fft_cuda_plan_s* cuplan)
{
	struct fft_cuda_plan_s* next = cuplan->chain;

	// tear down the tail of the chain before this element
	if (NULL != next)
		fft_cuda_free_plan(next);

	cufftDestroy(cuplan->cufft);

	xfree(cuplan);
}
......@@ -207,7 +243,9 @@ void fft_cuda_exec(struct fft_cuda_plan_s* cuplan, complex float* dst, const com
(!cuplan->backwards) ? CUFFT_FORWARD : CUFFT_INVERSE)))
error("CUFFT: %d\n", err);
}
}
if (NULL != cuplan->chain)
fft_cuda_exec(cuplan->chain, dst, dst);
}
#endif
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment