/* $Id: spm_mrf.c 4873 2012-08-30 19:06:26Z john $ */ /* (c) John Ashburner (2010) */ #include "mex.h" #include <math.h> #define MAXCLASSES 1024 static void mrf1(mwSize dm[], unsigned char q[], float p[], float G[], float w[], int code) { mwSize i0, i1, i2, k, m, n; float a[MAXCLASSES], e[MAXCLASSES], *p0 = NULL, *p1 = NULL; unsigned char *q0 = NULL, *q1 = NULL; int it; m = dm[0]*dm[1]*dm[2]; /* Use a red-black scheme, so the updates are for alternating voxels. Then do another pass to update the other half. A B A B A B B A B A B A A B A B A B B A B A B A Updates involve computing the number of neighbours of each type (stored in vector a), and using the connectivity matrix (G) to update with: q = (p.*exp(G'*a))/sum((p.*exp(G'*a)) */ for(it=0; it<2; it++) { mwSize i2start = it%2; for(i2=0; i2<dm[2]; i2++) /* Inferior -> Superior */ { mwSize i1start = (i2start == (i2%2)); for(i1=0; i1<dm[1]; i1++) /* Posterior -> Anterior */ { mwSize i0start = (i1start == (i1%2)); p1 = p + dm[0]*(i1+dm[1]*i2); q1 = q + dm[0]*(i1+dm[1]*i2); for(i0=i0start; i0<dm[0]; i0+=2) /* Left -> Right */ { float se; unsigned char *qq = NULL; /* Pointers to current voxel in first volume */ p0 = p1 + i0; q0 = q1 + i0; /* Initialise neighbour counts to zero */ for(k=0; k<dm[3]; k++) a[k] = 0.0; /* Count neighbours of each class */ if(i2>0) /* Inferior */ { qq = q0 - dm[0]*dm[1]; for(k=0; k<dm[3]; k++) a[k] += qq[k*m]*w[2]; } if(i2<dm[2]-1) /* Superior */ { qq = q0 + dm[0]*dm[1]; for(k=0; k<dm[3]; k++) a[k] += qq[k*m]*w[2]; } if(i1>0) /* Posterior */ { qq = q0 - dm[0]; for(k=0; k<dm[3]; k++) a[k] += qq[k*m]*w[1]; } if(i1<dm[1]-1) /* Anterior */ { qq = q0 + dm[0]; for(k=0; k<dm[3]; k++) a[k] += qq[k*m]*w[1]; } if(i0>0) /* Left */ { qq = q0 - 1; for(k=0; k<dm[3]; k++) a[k] += qq[k*m]*w[0]; } if(i0<dm[0]-1) /* Right */ { qq = q0 + 1; for(k=0; k<dm[3]; k++) a[k] += qq[k*m]*w[0]; } /* Responsibility data is uint8, so correct scaling. Note also that data is divided by 6 (the number of neighbours examined). */ for(k=0; k<dm[3]; k++) a[k]/=(255.0*6.0); if (code == 1) { /* Weights are in the form of a matrix, shared among all voxels. */ float *g; se = 0.0; for(k=0, g=G; k<dm[3]; k++) { e[k] = 0; for(n=0; n<dm[3]; n++, g++) e[k] += (*g)*a[n]; e[k] = exp((double)e[k])*p0[k*m]; se += e[k]; } } else if (code == 2) { /* Weights are assumed to be a diagonal matrix, so only the diagonal elements are passed. */ se = 0.0; for(k=0; k<dm[3]; k++) { e[k] = exp((double)(G[k]*a[k]))*p0[k*m]; se += e[k]; } } else if (code == 3) { /* Separate weights for each voxel, in the form of the full matrix (loads of memory). */ float *g; se = 0.0; g = G + i0+dm[0]*(i1+dm[1]*i2); for(k=0; k<dm[3]; k++) { e[k] = 0.0; for(n=0; n<dm[3]; n++, g+=m) e[k] += (*g)*a[n]; e[k] = exp((double)e[k])*p0[k*m]; se += e[k]; } } else if (code == 4) { /* Separate weight matrices for each voxel, where the matrices are assumed to be symmetric with zeros on the diagonal. For a 4x4 matrix, the elements are ordered as (2,1), (3,1), (4,1), (3,2), (4,2), (4,3). */ float *g; g = G + i0+dm[0]*(i1+dm[1]*i2); for(k=0; k<dm[3]; k++) e[k] = 0.0; for(k=0; k<dm[3]; k++) { for(n=k+1; n<dm[3]; n++, g+=m) { e[k] += (*g)*a[n]; e[n] += (*g)*a[k]; } } se = 0.0; for(k=0; k<dm[3]; k++) { e[k] = exp((double)e[k])*p0[k*m]; se += e[k]; } } else if (code == 5) { /* Separate weight matrices for each voxel, where the matrices are assumed to be symmetric with zeros on the diagonal. For a 4x4 matrix, the elements are ordered as (2,1), (3,1), (4,1), (3,2), (4,2), (4,3). The weight matrices are encoded as uint8, and their values need to be scaled by -0.0625 to bring them into a reasonable range. */ unsigned char *g; g = (unsigned char *)G + i0+dm[0]*(i1+dm[1]*i2); for(k=0; k<dm[3]; k++) e[k] = 0.0; for(k=0; k<dm[3]; k++) { for(n=k+1; n<dm[3]; n++, g+=m) { e[k] += ((float)(*g))*a[n]; e[n] += ((float)(*g))*a[k]; } } se = 0.0; for(k=0; k<dm[3]; k++) { e[k] = exp(-0.0625*e[k])*p0[k*m]; se += e[k]; } } /* Normalise responsibilities to sum to 1 and rescale for saving as uint8 data. */ se = 255.0/se; for(k=0; k<dm[3]; k++) q0[k*m] = (unsigned char)(e[k]*se+0.5); } } } } } void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) { mwSize i; mwSize dm[4]; float *p = NULL, w[3]; unsigned char *q = NULL; float *G = NULL; int code=0; if (nrhs<3 || nrhs>4 || nlhs>1) mexErrMsgTxt("Incorrect usage"); if (!mxIsNumeric(prhs[0]) || mxIsComplex(prhs[0]) || mxIsSparse(prhs[0]) || !mxIsUint8(prhs[0])) mexErrMsgTxt("First arg must be numeric, real, full and uint8."); if (!mxIsNumeric(prhs[1]) || mxIsComplex(prhs[1]) || mxIsSparse(prhs[1]) || !mxIsSingle(prhs[1])) mexErrMsgTxt("Second arg must be numeric, real, full and single."); if (!mxIsNumeric(prhs[2]) || mxIsComplex(prhs[2]) || mxIsSparse(prhs[2])) mexErrMsgTxt("Third arg must be numeric, real and full."); if (mxGetNumberOfDimensions(prhs[0])!= mxGetNumberOfDimensions(prhs[1]) || mxGetNumberOfDimensions(prhs[0])>4) mexErrMsgTxt("First or second args have wrong number of dimensions."); for(i=0; i<mxGetNumberOfDimensions(prhs[0]); i++) dm[i] = mxGetDimensions(prhs[0])[i]; for(i=mxGetNumberOfDimensions(prhs[0]); i<4; i++) dm[i] = 1; if (dm[3]>MAXCLASSES) mexErrMsgTxt("Too many classes."); for(i=0; i<4; i++) if (mxGetDimensions(prhs[1])[i] != dm[i]) mexErrMsgTxt("First and second args have incompatible dimensions."); if (mxGetDimensions(prhs[2])[1] == 1) { code = 2; if (mxGetDimensions(prhs[2])[0] != dm[3]) mexErrMsgTxt("Third arg has incompatible dimensions."); if (!mxIsSingle(prhs[2])) mexErrMsgTxt("Third arg must be single."); } else if (mxGetNumberOfDimensions(prhs[2])==2) { code = 1; if (mxGetDimensions(prhs[2])[0] != dm[3] || mxGetDimensions(prhs[2])[1] != dm[3]) mexErrMsgTxt("Third arg has incompatible dimensions."); if (!mxIsSingle(prhs[2])) mexErrMsgTxt("Third arg must be single."); } else if (mxGetNumberOfDimensions(prhs[2])==5) { code = 3; for(i=0; i<4; i++) if (mxGetDimensions(prhs[2])[i] != dm[i]) mexErrMsgTxt("Third arg has incompatible dimensions."); if (mxGetDimensions(prhs[2])[4] != dm[3]) mexErrMsgTxt("Third arg has incompatible dimensions."); if (!mxIsSingle(prhs[2])) mexErrMsgTxt("Third arg must be single."); } else if (mxGetNumberOfDimensions(prhs[2])==4) { for(i=0; i<3; i++) if (mxGetDimensions(prhs[2])[i] != dm[i]) mexErrMsgTxt("Third arg has incompatible dimensions."); if (mxGetDimensions(prhs[2])[3] != (dm[3]*(dm[3]-1))/2) mexErrMsgTxt("Third arg has incompatible dimensions."); if (mxIsSingle(prhs[2])) code = 4; else if (mxIsUint8(prhs[0])) code = 5; else mexErrMsgTxt("Third arg must be either single or uint8."); } else mexErrMsgTxt("Third arg has incompatible dimensions."); p = (float *)mxGetData(prhs[1]); G = (float *)mxGetData(prhs[2]); if (nrhs>=4) { /* Adjustment for anisotropic voxel sizes. w should contain the square of each voxel size. */ if (!mxIsNumeric(prhs[3]) || mxIsComplex(prhs[3]) || mxIsSparse(prhs[3]) || !mxIsSingle(prhs[3])) mexErrMsgTxt("Fourth arg must be numeric, real, full and single."); if (mxGetNumberOfElements(prhs[3]) != 3) mexErrMsgTxt("Fourth arg must contain three elements."); for(i=0; i<3; i++) w[i] = ((float *)mxGetData(prhs[3]))[i]; } else { for(i=0; i<3; i++) w[i] = 1.0; } if (nlhs>0) { /* Copy input to output */ unsigned char *q0; plhs[0] = mxCreateNumericArray(4,dm, mxUINT8_CLASS, mxREAL); q0 = (unsigned char *)mxGetData(prhs[0]); q = (unsigned char *)mxGetData(plhs[0]); for(i=0; i<dm[0]*dm[1]*dm[2]*dm[3]; i++) q[i] = q0[i]; } else /* Note the nasty side effects - but it does save memory */ q = (unsigned char *)mxGetData(prhs[0]); mrf1(dm, q,p,G,w,code); }