Introduction

I’ll walk through my implementation of the Single Pass Downsampler (SPD) algorithm for Apple’s Metal API. SPD generates an entire mipmap chain in a single compute shader dispatch; my version is based on AMD’s FidelityFX SPD. The full source code is at the end of the post.

What is a Single Pass Downsampler?

Traditional mipmap generation requires multiple passes: each mip level is computed in a separate dispatch, with a synchronization point between consecutive dispatches. The GPU has to finish each level, write it to device memory, and read it back for the next level, which costs memory bandwidth and adds per-dispatch API and synchronization overhead. SPD avoids most of this by keeping intermediate results in threadgroup memory, local to each compute threadgroup, so most mip levels are produced without reading the previous level back from device memory. SPD can be used with any texture, but I implemented it to generate a depth pyramid for occlusion culling.

The Dispatch Logic

Each threadgroup covers a 64×64 pixel region with 16×16 threads, so each thread processes a 4×4 block of the source: it copies the block to mip 0 and reduces it to a 2×2 block for mip 1. In the code, lid is the 2D position of a thread within its threadgroup and tgid is the 2D position of the threadgroup in the grid.
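On the host side, the tiling above means the number of threadgroups is the source size divided by 64, rounded up. The sketch below shows that arithmetic together with the kernel's thread-to-pixel mapping; the helper names (threadgroupCount, baseCoord, UInt2) are hypothetical, and in a real app the result would be passed as the threadgroup count to dispatchThreadgroups:threadsPerThreadgroup: with 16×16×1 threads per threadgroup.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helpers mirroring the kernel's tiling:
// 64x64 source pixels per threadgroup, 16x16 threads, 4x4 pixels per thread.
constexpr uint32_t kTileSize = 64;
constexpr uint32_t kThreadsPerSide = 16;
constexpr uint32_t kPixelsPerThread = kTileSize / kThreadsPerSide; // 4

struct UInt2 { uint32_t x, y; };

// Threadgroups needed to cover the source, rounded up in each dimension.
UInt2 threadgroupCount(uint32_t width, uint32_t height) {
    assert(width > 0 && height > 0);
    return { (width + kTileSize - 1) / kTileSize,
             (height + kTileSize - 1) / kTileSize };
}

// Top-left source pixel of the 4x4 block read by thread lid of threadgroup tgid,
// i.e. the kernel's baseCoord = tgid * 64 + lid * 4.
UInt2 baseCoord(UInt2 tgid, UInt2 lid) {
    return { tgid.x * kTileSize + lid.x * kPixelsPerThread,
             tgid.y * kTileSize + lid.y * kPixelsPerThread };
}
```

For 1920×1080 this gives a 30×17 grid, i.e. 510 threadgroups; the bottom row extends past the source, which is why the kernel clamps its reads and bounds-checks its writes.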

Figure 1: Threadgroup layout for the Single Pass Downsampler (SPD threadgroup organization)

Compute kernel

For the kernel we only need the source texture (the depth texture from the z-prepass, in my case), the destination texture holding the mip chain, a global atomic counter, which I will explain later, and a struct of constants.

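struct SPDConstants {
    uint2 srcSize;        // source depth texture size in pixels
    uint  mipCount;       // number of mip levels in the pyramid texture
    uint  numWorkgroups;  // total number of threadgroups in the dispatch
};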
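The constants buffer is filled on the CPU, so the host-side struct has to match the MSL layout byte for byte. A minimal sketch, assuming plain C++ types rather than simd_uint2 from <simd/simd.h>: in MSL, uint2 is 8 bytes with 8-byte alignment and uint is 4 bytes, so SPDConstants is 16 bytes. SPDConstantsHost is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical host-side mirror of the MSL SPDConstants struct.
struct alignas(8) SPDConstantsHost {
    uint32_t srcSize[2];     // matches uint2 srcSize (8 bytes, 8-byte aligned)
    uint32_t mipCount;       // matches uint mipCount
    uint32_t numWorkgroups;  // matches uint numWorkgroups
};

// Compile-time checks that the layout agrees with the shader's view.
static_assert(sizeof(SPDConstantsHost) == 16, "must match the MSL struct size");
static_assert(offsetof(SPDConstantsHost, mipCount) == 8, "mipCount follows srcSize");
static_assert(offsetof(SPDConstantsHost, numWorkgroups) == 12, "numWorkgroups is last");
```

With the 1920×1080 example used later in the post, the values would be {{1920, 1080}, 11, 510}.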

Mip 0 case

With a multi-pass reduction we would usually need a special case or a separate kernel to copy the source texture into mip 0 of the destination texture. With SPD we have to read the source texture anyway for the first reduction to mip 1, so we can write the values straight to mip 0 before reducing them.

kernel void depth_pyramid_spd(texture2d<float, access::read>        srcDepth        [[texture(0)]],
                              texture2d<float, access::read_write>  pyramid         [[texture(1)]],
                       device atomic_uint&                          globalCounter   [[buffer(0)]],
                     constant SPDConstants&                         constants       [[buffer(1)]],
                              uint2                                 lid             [[thread_position_in_threadgroup]],
                              uint2                                 tgid            [[threadgroup_position_in_grid]],
                              uint                                  lid_flat        [[thread_index_in_threadgroup]]) {
threadgroup float lds[32][32];

uint2 srcSize = constants.srcSize;
uint2 baseCoord = tgid * 64 + lid * 4;

// load 4x4 from source, copy to mip 0, reduce to 2x2 for mip 1

// row 0
float s00 = srcDepth.read(min(baseCoord + uint2(0, 0), srcSize - 1)).r;
float s10 = srcDepth.read(min(baseCoord + uint2(1, 0), srcSize - 1)).r;
float s20 = srcDepth.read(min(baseCoord + uint2(2, 0), srcSize - 1)).r;
float s30 = srcDepth.read(min(baseCoord + uint2(3, 0), srcSize - 1)).r;

// row 1
float s01 = srcDepth.read(min(baseCoord + uint2(0, 1), srcSize - 1)).r;
float s11 = srcDepth.read(min(baseCoord + uint2(1, 1), srcSize - 1)).r;
float s21 = srcDepth.read(min(baseCoord + uint2(2, 1), srcSize - 1)).r;
float s31 = srcDepth.read(min(baseCoord + uint2(3, 1), srcSize - 1)).r;

// row 2
float s02 = srcDepth.read(min(baseCoord + uint2(0, 2), srcSize - 1)).r;
float s12 = srcDepth.read(min(baseCoord + uint2(1, 2), srcSize - 1)).r;
float s22 = srcDepth.read(min(baseCoord + uint2(2, 2), srcSize - 1)).r;
float s32 = srcDepth.read(min(baseCoord + uint2(3, 2), srcSize - 1)).r;

// row 3
float s03 = srcDepth.read(min(baseCoord + uint2(0, 3), srcSize - 1)).r;
float s13 = srcDepth.read(min(baseCoord + uint2(1, 3), srcSize - 1)).r;
float s23 = srcDepth.read(min(baseCoord + uint2(2, 3), srcSize - 1)).r;
float s33 = srcDepth.read(min(baseCoord + uint2(3, 3), srcSize - 1)).r;

// write 4x4 block to mip 0 to get a full copy

if (all(baseCoord + uint2(0, 0) < srcSize)) pyramid.write(float4(s00, 0, 0, 1), baseCoord + uint2(0, 0), 0);
if (all(baseCoord + uint2(1, 0) < srcSize)) pyramid.write(float4(s10, 0, 0, 1), baseCoord + uint2(1, 0), 0);
if (all(baseCoord + uint2(2, 0) < srcSize)) pyramid.write(float4(s20, 0, 0, 1), baseCoord + uint2(2, 0), 0);
if (all(baseCoord + uint2(3, 0) < srcSize)) pyramid.write(float4(s30, 0, 0, 1), baseCoord + uint2(3, 0), 0);

if (all(baseCoord + uint2(0, 1) < srcSize)) pyramid.write(float4(s01, 0, 0, 1), baseCoord + uint2(0, 1), 0);
if (all(baseCoord + uint2(1, 1) < srcSize)) pyramid.write(float4(s11, 0, 0, 1), baseCoord + uint2(1, 1), 0);
if (all(baseCoord + uint2(2, 1) < srcSize)) pyramid.write(float4(s21, 0, 0, 1), baseCoord + uint2(2, 1), 0);
if (all(baseCoord + uint2(3, 1) < srcSize)) pyramid.write(float4(s31, 0, 0, 1), baseCoord + uint2(3, 1), 0);

if (all(baseCoord + uint2(0, 2) < srcSize)) pyramid.write(float4(s02, 0, 0, 1), baseCoord + uint2(0, 2), 0);
if (all(baseCoord + uint2(1, 2) < srcSize)) pyramid.write(float4(s12, 0, 0, 1), baseCoord + uint2(1, 2), 0);
if (all(baseCoord + uint2(2, 2) < srcSize)) pyramid.write(float4(s22, 0, 0, 1), baseCoord + uint2(2, 2), 0);
if (all(baseCoord + uint2(3, 2) < srcSize)) pyramid.write(float4(s32, 0, 0, 1), baseCoord + uint2(3, 2), 0);

if (all(baseCoord + uint2(0, 3) < srcSize)) pyramid.write(float4(s03, 0, 0, 1), baseCoord + uint2(0, 3), 0);
if (all(baseCoord + uint2(1, 3) < srcSize)) pyramid.write(float4(s13, 0, 0, 1), baseCoord + uint2(1, 3), 0);
if (all(baseCoord + uint2(2, 3) < srcSize)) pyramid.write(float4(s23, 0, 0, 1), baseCoord + uint2(2, 3), 0);
if (all(baseCoord + uint2(3, 3) < srcSize)) pyramid.write(float4(s33, 0, 0, 1), baseCoord + uint2(3, 3), 0);
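The edge handling above is easy to get wrong: reads are clamped to the last valid texel, while writes outside the source are skipped. With 1920×1080, the bottom row of threadgroups (tgid.y = 16) covers rows 1024 to 1087, so its reads clamp to row 1079 and its writes past that row are dropped. Below is a small CPU model of one thread's 4×4 block, a sketch for checking that logic; Image and copyBlockToMip0 are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical row-major single-channel image.
struct Image {
    uint32_t width, height;
    std::vector<float> texels;
    float at(uint32_t x, uint32_t y) const { return texels[y * width + x]; }
};

// CPU model of one thread's work for mip 0: read a 4x4 block starting at
// (bx, by) with clamped coordinates, copy in-bounds texels to dst.
// Returns how many texels were written.
int copyBlockToMip0(const Image& src, Image& dst, uint32_t bx, uint32_t by,
                    float block[4][4]) {
    assert(dst.width == src.width && dst.height == src.height);
    int written = 0;
    for (uint32_t j = 0; j < 4; ++j) {
        for (uint32_t i = 0; i < 4; ++i) {
            uint32_t x = bx + i, y = by + j;
            // Clamped read, like srcDepth.read(min(coord, srcSize - 1)).
            block[j][i] = src.at(std::min(x, src.width - 1), std::min(y, src.height - 1));
            // Bounds-checked write, like if (all(coord < srcSize)) pyramid.write(...).
            if (x < src.width && y < src.height) {
                dst.texels[y * dst.width + x] = block[j][i];
                ++written;
            }
        }
    }
    return written;
}
```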

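The threadgroup array lds (named after AMD’s local data share, LDS) is 32×32 floats (4 KB) instead of the 16×16 that AMD uses, which lets the whole 32×32 result of the first reduction fit at once. For simplicity I also didn’t use subgroup (SIMD-group) operations.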

Reduction

For a depth pyramid we want the maximum (farthest) depth of each 2×2 block, written to the next mip level, so that occlusion tests against the pyramid stay conservative. With reversed-Z, where far maps to 0, use a minimum reduction instead.

float reduce2x2(float a, float b, float c, float d) {
    return max(max(a, b), max(c, d));
}
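A CPU reference for one level of the pyramid is handy for validating the GPU output. This sketch applies the same 2×2 reduction over a whole level, with floor-based sizes (size >> 1, at least 1) and clamped reads as in the kernel; the reverseZ flag and the Level/downsample names are my own additions, not part of the shader.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical row-major pyramid level.
struct Level {
    uint32_t width, height;
    std::vector<float> texels;
};

// Max for standard depth (keep the farthest), min for reversed-Z.
float reduce2x2(float a, float b, float c, float d, bool reverseZ) {
    return reverseZ ? std::min(std::min(a, b), std::min(c, d))
                    : std::max(std::max(a, b), std::max(c, d));
}

// One downsample step: each destination texel reduces the 2x2 texels below it.
Level downsample(const Level& src, bool reverseZ) {
    assert(src.width > 0 && src.height > 0);
    Level dst{std::max(src.width >> 1, uint32_t(1)),
              std::max(src.height >> 1, uint32_t(1)), {}};
    dst.texels.resize(std::size_t(dst.width) * dst.height);
    // Clamped read, matching min(coord, size - 1) in the kernel.
    auto at = [&](uint32_t x, uint32_t y) {
        return src.texels[std::size_t(std::min(y, src.height - 1)) * src.width +
                          std::min(x, src.width - 1)];
    };
    for (uint32_t y = 0; y < dst.height; ++y)
        for (uint32_t x = 0; x < dst.width; ++x)
            dst.texels[std::size_t(y) * dst.width + x] =
                reduce2x2(at(2 * x, 2 * y), at(2 * x + 1, 2 * y),
                          at(2 * x, 2 * y + 1), at(2 * x + 1, 2 * y + 1), reverseZ);
    return dst;
}
```

Calling downsample repeatedly until the level is 1×1 produces the whole reference chain.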

Each threadgroup covers 64×64 pixels with 16×16 threads, so at mip 1 each thread performs four 2×2 reductions, one per quadrant of its 4×4 block.

// reduce 4x4 to 2x2 for mip 1
float r0 = reduce2x2(s00, s10, s01, s11);
float r1 = reduce2x2(s20, s30, s21, s31);
float r2 = reduce2x2(s02, s12, s03, s13);
float r3 = reduce2x2(s22, s32, s23, s33);

uint2 ldsBase = lid * 2;
lds[ldsBase.y + 0][ldsBase.x + 0] = r0;
lds[ldsBase.y + 0][ldsBase.x + 1] = r1;
lds[ldsBase.y + 1][ldsBase.x + 0] = r2;
lds[ldsBase.y + 1][ldsBase.x + 1] = r3;

// mip 1: 64x64 to 32x32
uint2 mip1Base = tgid * 32 + ldsBase;
uint2 mip1Size = srcSize >> 1;

if (all(mip1Base + uint2(0, 0) < mip1Size)) pyramid.write(float4(r0, 0, 0, 1), mip1Base + uint2(0, 0), 1);
if (all(mip1Base + uint2(1, 0) < mip1Size)) pyramid.write(float4(r1, 0, 0, 1), mip1Base + uint2(1, 0), 1);
if (all(mip1Base + uint2(0, 1) < mip1Size)) pyramid.write(float4(r2, 0, 0, 1), mip1Base + uint2(0, 1), 1);
if (all(mip1Base + uint2(1, 1) < mip1Size)) pyramid.write(float4(r3, 0, 0, 1), mip1Base + uint2(1, 1), 1);

threadgroup_barrier(mem_flags::mem_threadgroup);

After the first reduction, the 64×64 tile has shrunk to 32×32 values, which fit in lds and are used for the following mips. ldsBase is the thread’s position times 2, because each thread now owns a 2×2 block of lds. The barrier makes every thread wait until all lds writes are done before any thread reads lds for the next mip level.
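To convince yourself that the mip 1 indexing is right, you can check on the CPU that 16×16 threads, each writing a 2×2 block at lid * 2, touch every cell of the 32×32 array exactly once, and that neighboring threadgroups produce adjacent mip 1 texels. A small sketch; both function names are hypothetical.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Counts how often each lds cell is written by lds[lid.y*2 + dy][lid.x*2 + dx].
bool mip1WritesCoverLdsOnce() {
    std::array<std::array<int, 32>, 32> hits{};
    for (uint32_t ly = 0; ly < 16; ++ly)
        for (uint32_t lx = 0; lx < 16; ++lx)
            for (uint32_t dy = 0; dy < 2; ++dy)
                for (uint32_t dx = 0; dx < 2; ++dx)
                    ++hits[ly * 2 + dy][lx * 2 + dx];
    for (const auto& row : hits)
        for (int h : row)
            if (h != 1) return false;
    return true;
}

// One axis of the mip 1 texel written by a thread: tgid * 32 + lid * 2 + offset.
uint32_t mip1Coord(uint32_t tgid, uint32_t lid, uint32_t offset) {
    assert(offset < 2);
    return tgid * 32 + lid * 2 + offset;
}
```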

Reductions until 1×1

The reductions continue in the same fashion up to mip 6, where each threadgroup has reduced its original 64×64 block to a single pixel. From mip 3 onward there are fewer output pixels than threads, so the remaining threads sit idle.
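How many threads actually work at each level follows directly from the tile size, 64 >> mip per side. A sketch of that arithmetic, with hypothetical helper names: at mip 1 all 256 threads write a 2×2 block each, at mip 2 all 256 write one texel, and from there the count drops by 4× per level.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Side length of the region one threadgroup produces at a given mip (1..6).
uint32_t tileSide(uint32_t mip) { return 64u >> mip; }

// Threads of the 256 in a threadgroup that produce output at a given mip.
uint32_t activeThreads(uint32_t mip) {
    assert(mip >= 1 && mip <= 6);
    uint32_t outputs = tileSide(mip) * tileSide(mip);
    // Mip 1: each thread writes four texels; later mips: one texel per thread.
    return mip == 1 ? outputs / 4 : std::min(outputs, uint32_t(256));
}
```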

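// mip 2: 32x32 to 16x16
float reduced = reduce2x2(lds[ldsBase.y + 0][ldsBase.x + 0],
                          lds[ldsBase.y + 0][ldsBase.x + 1],
                          lds[ldsBase.y + 1][ldsBase.x + 0],
                          lds[ldsBase.y + 1][ldsBase.x + 1]);

// lds is overwritten in place: wait until every thread has read its 2x2 inputs,
// otherwise e.g. the thread at lid (1, 0) could overwrite lds[0][1] before the
// thread at lid (0, 0) has read it
threadgroup_barrier(mem_flags::mem_threadgroup);

lds[lid.y][lid.x] = reduced;

uint2 mip2Coord = tgid * 16 + lid;
uint2 mip2Size = srcSize >> 2;
if (all(mip2Coord < mip2Size)) {
    pyramid.write(float4(reduced, 0, 0, 1), mip2Coord, 2);
}

threadgroup_barrier(mem_flags::mem_threadgroup);

// mip 3: 16x16 to 8x8, only the first 8x8 threads are active
bool activeMip3 = all(lid < 8);
if (activeMip3) {
    reduced = reduce2x2(lds[ldsBase.y + 0][ldsBase.x + 0],
                        lds[ldsBase.y + 0][ldsBase.x + 1],
                        lds[ldsBase.y + 1][ldsBase.x + 0],
                        lds[ldsBase.y + 1][ldsBase.x + 1]);
}

// same read-before-overwrite barrier; it sits outside the branch because
// every thread of the threadgroup must reach a threadgroup_barrier
threadgroup_barrier(mem_flags::mem_threadgroup);

if (activeMip3) {
    lds[lid.y][lid.x] = reduced;

    uint2 mip3Coord = tgid * 8 + lid;
    uint2 mip3Size = srcSize >> 3;
    if (all(mip3Coord < mip3Size)) {
        pyramid.write(float4(reduced, 0, 0, 1), mip3Coord, 3);
    }
}

threadgroup_barrier(mem_flags::mem_threadgroup);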

// ...
// ...

// mip 6: 2x2 to 1x1
if (lid_flat == 0) {
    reduced = reduce2x2(lds[0][0], lds[0][1], lds[1][0], lds[1][1]);

    uint2 mip6Coord = tgid;
    uint2 mip6Size = srcSize >> 6;
    if (all(mip6Coord < mip6Size)) {
        pyramid.write(float4(reduced, 0, 0, 1), mip6Coord, 6);
    }
}

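// mem_texture orders the pyramid writes (pyramid is a texture, which mem_device
// alone does not cover) before thread 0 touches the global counter below
threadgroup_barrier(mem_flags::mem_device | mem_flags::mem_texture | mem_flags::mem_threadgroup);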
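A sequential CPU emulation of one threadgroup’s chain, mips 1 to 6, is a useful sanity check: since max is associative, the single mip 6 value must equal the maximum of the whole 64×64 tile. This sketch assumes a fully in-bounds tile and uses a 32×32 scratch array in place of lds; each level is computed completely before lds is overwritten, which is the ordering a barrier between the lds reads and writes provides on the GPU. reduceTile is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Returns the per-level outputs; outputs[k] holds mip k+1 (side 32 >> k).
std::vector<std::vector<float>> reduceTile(const std::vector<float>& tile) {
    assert(tile.size() == 64 * 64);
    std::vector<std::vector<float>> outputs;
    std::vector<float> lds(32 * 32);  // stands in for threadgroup float lds[32][32]

    // Mip 1: 64x64 -> 32x32, straight from the source tile.
    for (int y = 0; y < 32; ++y)
        for (int x = 0; x < 32; ++x) {
            auto s = [&](int dx, int dy) { return tile[(2 * y + dy) * 64 + 2 * x + dx]; };
            lds[y * 32 + x] = std::max(std::max(s(0, 0), s(1, 0)), std::max(s(0, 1), s(1, 1)));
        }
    outputs.push_back(lds);

    // Mips 2..6: halve the side each time, reading the previous level from lds.
    for (int side = 16; side >= 1; side /= 2) {
        std::vector<float> level(side * side);
        for (int y = 0; y < side; ++y)
            for (int x = 0; x < side; ++x) {
                auto in = [&](int dx, int dy) { return lds[(2 * y + dy) * 32 + 2 * x + dx]; };
                level[y * side + x] = std::max(std::max(in(0, 0), in(1, 0)),
                                               std::max(in(0, 1), in(1, 1)));
            }
        // Only after all reads are done, store the new level back into lds[y][x].
        for (int y = 0; y < side; ++y)
            for (int x = 0; x < side; ++x) lds[y * 32 + x] = level[y * side + x];
        outputs.push_back(level);
    }
    return outputs;
}
```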

Understanding the Threadgroup Synchronization

Before we continue with the rest of the mips, let’s work through a concrete example.

Take a 1920×1080 depth buffer, with the mip count calculated as:

mipCount = static_cast<uint32_t>(ceil(log2(fmax(depthWidth, depthHeight))));

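This gives 11 mip levels (0 to 10). The dispatch uses ceil(1920/64) × ceil(1080/64) = 30×17 = 510 threadgroups, and each of them produces one pixel of mip 6. (Mip 6 itself is 30×16, because 1080 >> 6 = 16, so the bottom row of threadgroups falls outside it and skips the write.) The idea that FidelityFX SPD introduces is that the threadgroup that finishes mip 6 last also reduces all of the remaining mips. A threadgroup barrier, whatever its flags, only synchronizes the threads of one threadgroup; it cannot make the 510 threadgroups wait for each other. The memory flags on the last barrier order this threadgroup’s pyramid writes before thread 0 moves on to the global counter, and the counter is what tells us that all the other threadgroups are done.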
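The numbers above can be checked with a few lines of host code. This sketch evaluates the post’s mip count formula, compares it with the length of a full chain down to 1×1 (floor(log2(max)) + 1), and lists the floor-based level sizes the kernel uses. For 1920×1080 the two counts agree, but for a power-of-two size such as 1024×512 the ceil formula gives 10 levels instead of 11, so the chain stops at 2×1; whether that matters depends on how the culling pass samples the pyramid. All function names are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <utility>
#include <vector>

// The post's formula: ceil(log2(max(width, height))).
uint32_t mipCountCeil(uint32_t w, uint32_t h) {
    assert(w > 0 && h > 0);
    return static_cast<uint32_t>(std::ceil(std::log2(std::fmax(double(w), double(h)))));
}

// Full chain down to 1x1 with floor-based sizes: floor(log2(max)) + 1,
// computed with integer shifts.
uint32_t mipCountFull(uint32_t w, uint32_t h) {
    uint32_t m = w > h ? w : h;
    uint32_t levels = 1;
    while (m > 1) { m >>= 1; ++levels; }
    return levels;
}

// Floor-based mip sizes, as in the kernel (srcSize >> mip, clamped to at least 1).
std::vector<std::pair<uint32_t, uint32_t>> mipSizes(uint32_t w, uint32_t h, uint32_t count) {
    std::vector<std::pair<uint32_t, uint32_t>> sizes;
    for (uint32_t mip = 0; mip < count; ++mip) {
        uint32_t mw = w >> mip, mh = h >> mip;
        sizes.emplace_back(mw > 0 ? mw : uint32_t(1), mh > 0 ? mh : uint32_t(1));
    }
    return sizes;
}
```

For 1920×1080, mipSizes shows mip 6 as 30×16 and mip 7 as 15×8.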

Atomic Global Counter

To know which threadgroup finished last, we use a global atomic counter stored in a device buffer.

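if (lid_flat == 0) {
    // atomic_fetch_add_explicit returns the value before the increment
    uint completed = atomic_fetch_add_explicit(&globalCounter, 1, memory_order_relaxed) + 1;

    if (completed == constants.numWorkgroups) {
        // last threadgroup: reset the counter for the next dispatch
        atomic_store_explicit(&globalCounter, 0, memory_order_relaxed);
        // ...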

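Thread 0 of each threadgroup increments the global counter (lid_flat is the flattened 1D index of the thread within the threadgroup, the 1D counterpart of lid). If the incremented value equals numWorkgroups (510 in this example), every other threadgroup has already passed this point, so this thread resets the counter and continues reducing mips 7 and up, reading the pixels of the previous mip directly from the pyramid texture.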

        // the remaining mips are small, so one thread reduces them serially,
        // reading each level back from the pyramid texture
        for (uint mip = 7; mip < constants.mipCount; mip++) {
            uint2 dstSize = max(srcSize >> mip, uint2(1, 1));
            uint2 srcMipSize = max(srcSize >> (mip - 1), uint2(1, 1));

            for (uint y = 0; y < dstSize.y; y++) {
                for (uint x = 0; x < dstSize.x; x++) {
                    uint2 readBase = uint2(x, y) * 2;
                    float r0 = pyramid.read(min(readBase + uint2(0, 0), srcMipSize - 1), mip - 1).r;
                    float r1 = pyramid.read(min(readBase + uint2(1, 0), srcMipSize - 1), mip - 1).r;
                    float r2 = pyramid.read(min(readBase + uint2(0, 1), srcMipSize - 1), mip - 1).r;
                    float r3 = pyramid.read(min(readBase + uint2(1, 1), srcMipSize - 1), mip - 1).r;
                    pyramid.write(float4(reduce2x2(r0, r1, r2, r3), 0, 0, 1), uint2(x, y), mip);
                }
            }

            // make this thread's writes to level `mip` visible to its own reads
            // in the next iteration (needed for read_write textures)
            pyramid.fence();
        }
    }
}

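Since only one thread does this work, no threadgroup barriers are needed. The thread does read back texels it wrote itself in the previous iteration, and for a read_write texture that requires a fence() call on the texture between the write and the read. The remaining levels are also small (mip 7 is only 15×8 for 1920×1080), so a serial loop is cheap.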
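The last-threadgroup-out pattern is not specific to GPUs, and a CPU analogue with std::atomic and std::thread shows the logic in isolation: every worker publishes its partial result, then increments a shared counter, and the worker that sees the final count does the tail reduction and resets the counter. This is a sketch with hypothetical names; it uses acq_rel ordering so the last worker is guaranteed to see the others’ writes, whereas the Metal kernel uses memory_order_relaxed.

```cpp
#include <algorithm>
#include <atomic>
#include <cassert>
#include <thread>
#include <vector>

struct TailResult { float value; int tailRuns; unsigned counterAfter; };

TailResult lastOneOut(const std::vector<float>& inputs) {
    assert(!inputs.empty());
    const unsigned numWorkers = static_cast<unsigned>(inputs.size());
    std::vector<float> partial(numWorkers);  // stands in for the mip 6 texels
    std::atomic<unsigned> counter{0};
    std::atomic<int> tailRuns{0};
    float result = 0.0f;

    std::vector<std::thread> workers;
    for (unsigned i = 0; i < numWorkers; ++i) {
        workers.emplace_back([&, i] {
            partial[i] = inputs[i];  // per-worker work (here just a copy)
            // fetch_add returns the previous value, hence + 1
            unsigned completed = counter.fetch_add(1, std::memory_order_acq_rel) + 1;
            if (completed == numWorkers) {
                // last worker: every partial result is visible, finish the reduction
                counter.store(0, std::memory_order_relaxed);  // ready for reuse
                result = *std::max_element(partial.begin(), partial.end());
                tailRuns.fetch_add(1, std::memory_order_relaxed);
            }
        });
    }
    for (auto& t : workers) t.join();
    return {result, tailRuns.load(), counter.load()};
}
```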

Source code

References