ogl_beamforming

Ultrasound Beamforming Implemented with OpenGL
git clone anongit@rnpnr.xyz:ogl_beamforming.git
Log | Files | Refs | Feed | Submodules | README | LICENSE

Commit: aa5206908510552b6daeed83b2a5954014943133
Parent: fde09cde680e00eeea6ced6a452d9d16a9326209
Author: Randy Palamar
Date:   Sat, 25 Oct 2025 21:37:31 -0600

shaders/filter: use LDS to cooperatively load needed samples

also iq calculation can be performed once per sample instead of
each time a thread needs it.

this gives a ~28% speed up for a chirp filtering test.

Diffstat:
Mbeamformer.c | 8+++++---
Mbeamformer.meta | 1+
Mgenerated/beamformer.meta.c | 6++++--
Mshaders/filter.glsl | 65++++++++++++++++++++++++++++++++++++++---------------------------
4 files changed, 48 insertions(+), 32 deletions(-)

diff --git a/beamformer.c b/beamformer.c @@ -23,9 +23,9 @@ global f32 dt_for_frame; -#define FILTER_LOCAL_SIZE_X 64 -#define FILTER_LOCAL_SIZE_Y 1 -#define FILTER_LOCAL_SIZE_Z 1 +#define FILTER_LOCAL_SIZE_X 128 +#define FILTER_LOCAL_SIZE_Y 1 +#define FILTER_LOCAL_SIZE_Z 1 #define DECODE_LOCAL_SIZE_X 4 #define DECODE_LOCAL_SIZE_Y 1 @@ -573,6 +573,7 @@ plan_compute_pipeline(BeamformerComputePlan *cp, BeamformerParameterBlock *pb) bp->Filter.demodulation_frequency = pb->parameters.demodulation_frequency; bp->Filter.sampling_frequency = pb->parameters.sampling_frequency / 2; bp->Filter.decimation_rate = decimation_rate; + bp->Filter.sample_count = pb->parameters.sample_count; if (first) { bp->Filter.input_channel_stride = pb->parameters.raw_data_dimensions[0] / 2; @@ -601,6 +602,7 @@ plan_compute_pipeline(BeamformerComputePlan *cp, BeamformerParameterBlock *pb) bp->Filter.input_channel_stride = sample_count * pb->parameters.acquisition_count; bp->Filter.input_sample_stride = 1; bp->Filter.input_transmit_stride = sample_count; + bp->Filter.sample_count = sample_count; } /* TODO(rnp): filter may need a different dispatch layout */ diff --git a/beamformer.meta b/beamformer.meta @@ -90,6 +90,7 @@ @BakeInt(OutputChannelStride output_channel_stride ) @BakeInt(OutputSampleStride output_sample_stride ) @BakeInt(OutputTransmitStride output_transmit_stride) + @BakeInt(SampleCount sample_count ) @BakeInt(SamplingMode sampling_mode ) @BakeFloat(DemodulationFrequency demodulation_frequency) @BakeFloat(SamplingFrequency sampling_frequency ) diff --git a/generated/beamformer.meta.c b/generated/beamformer.meta.c @@ -110,6 +110,7 @@ typedef struct { u32 output_channel_stride; u32 output_sample_stride; u32 output_transmit_stride; + u32 sample_count; u32 sampling_mode; f32 demodulation_frequency; f32 sampling_frequency; @@ -304,6 +305,7 @@ read_only global s8 *beamformer_shader_bake_parameter_names[] = { s8_comp("OutputChannelStride"), s8_comp("OutputSampleStride"), s8_comp("OutputTransmitStride"), + s8_comp("SampleCount"), s8_comp("SamplingMode"), s8_comp("DemodulationFrequency"), s8_comp("SamplingFrequency"), @@ -330,7 +332,7 @@ read_only global s8 *beamformer_shader_bake_parameter_names[] = { read_only global u8 *beamformer_shader_bake_parameter_is_float[] = { (u8 []){0, 0, 0, 0, 0, 0, 0, 0}, - (u8 []){0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1}, + (u8 []){0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1}, (u8 []){0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1}, 0, 0, @@ -339,7 +341,7 @@ read_only global u8 *beamformer_shader_bake_parameter_is_float[] = { read_only global i32 beamformer_shader_bake_parameter_counts[] = { 8, - 11, + 12, 13, 0, 0, diff --git a/shaders/filter.glsl b/shaders/filter.glsl @@ -1,4 +1,6 @@ /* See LICENSE for license details. */ +/* TODO(rnp): bug: this won't filter RF data correctly */ +#define SAMPLE_TYPE vec2 #if DataKind == DataKind_Float32 #define DATA_TYPE vec2 #define RESULT_TYPE_CAST(v) (v) @@ -26,7 +28,7 @@ layout(std430, binding = 2) writeonly restrict buffer buffer_2 { }; layout(std430, binding = 3) readonly restrict buffer buffer_3 { - FILTER_TYPE filter_coefficients[]; + FILTER_TYPE filter_coefficients[FilterLength]; }; layout(r16i, binding = 1) readonly restrict uniform iimage1D channel_mapping; @@ -69,15 +71,16 @@ vec2 rotate_iq(vec2 iq, int index) } #endif -vec2 sample_rf(uint index) +SAMPLE_TYPE sample_rf(uint index) { - vec2 result = SAMPLE_TYPE_CAST(in_data[index]); + SAMPLE_TYPE result = SAMPLE_TYPE_CAST(in_data[index]); return result; } +shared SAMPLE_TYPE local_samples[FilterLength + gl_WorkGroupSize.x]; + void main() { - uint in_sample = gl_GlobalInvocationID.x * DecimationRate; uint out_sample = gl_GlobalInvocationID.x; uint channel = gl_GlobalInvocationID.y; uint transmit = gl_GlobalInvocationID.z; @@ -88,32 +91,40 @@ void main() OutputTransmitStride * transmit + OutputSampleStride * out_sample; - int target; - if (bool(MapChannels)) { - target = OutputChannelStride / OutputSampleStride; - } else { - target = OutputTransmitStride; - } - - if (out_sample < target) { - target *= DecimationRate; - - vec2 result = vec2(0); - int a_length = target; - int index = int(in_sample); + int thread_index = int(gl_LocalInvocationIndex); + int thread_count = int(gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z); + ///////////////////////// + // NOTE: sample caching + { + int min_sample = DecimationRate * int((gl_WorkGroupID.x + 0) * gl_WorkGroupSize.x) - FilterLength; + int max_sample = DecimationRate * int((gl_WorkGroupID.x + 1) * gl_WorkGroupSize.x); + + in_offset += min_sample; + int total_samples = max_sample - min_sample; + int samples_per_thread = total_samples / thread_count; + int leftover_count = total_samples % thread_count; + int thread_first_index = samples_per_thread * thread_index + min(thread_index, leftover_count); + int thread_last_index = thread_first_index + samples_per_thread + int(thread_index < leftover_count); const float scale = bool(ComplexFilter) ? 1 : sqrt(2); - - for (int j = max(0, index - FilterLength); j < min(index, a_length); j++) { - vec2 iq = sample_rf(in_offset + j); - FILTER_TYPE h = filter_coefficients[index - j]; - #if Demodulate - result += scale * apply_filter(rotate_iq(iq * vec2(1, -1), -j), h); - #else - result += apply_filter(iq, h); - #endif + for (int i = thread_first_index; i <= thread_last_index; i++) { + SAMPLE_TYPE valid = SAMPLE_TYPE(i + min_sample >= 0); + #if Demodulate + local_samples[i] = scale * valid * rotate_iq(sample_rf(in_offset + i) * vec2(1, -1), -i); + #else + local_samples[i] = valid * sample_rf(in_offset + i); + #endif + } + } + barrier(); + + if (out_sample < SampleCount / DecimationRate) { + SAMPLE_TYPE result = SAMPLE_TYPE(0); + int offset = DecimationRate * thread_index; + for (int j = 0; j < FilterLength; j++) { + result += apply_filter(local_samples[offset + j], + filter_coefficients[FilterLength - 1 - j]); } - out_data[out_offset] = RESULT_TYPE_CAST(result); } }