filter.glsl (3306B)
/* See LICENSE for license details. */
/* TODO(rnp): bug: this won't filter RF data correctly */

/* Compute shader: optional demodulation followed by a decimating FIR filter.
 *
 * NOTE(review): this file contains no #version line or local_size layout and it
 * references several names it never defines (DataKind, DataKind_Float32,
 * OutputFloats, ComplexFilter, Demodulate, FilterLength, DecimationRate,
 * SampleCount, SamplingFrequency, DemodulationFrequency and the Input/Output
 * stride values). Presumably they are prepended by the host before compilation;
 * confirm against the code that builds the shader source. */

/* Every sample is handled as a 2-component value inside the shader. */
#define SAMPLE_TYPE vec2
#if DataKind == DataKind_Float32
	/* float input: one vec2 per sample, no conversion on load or store */
	#define DATA_TYPE           vec2
	#define RESULT_TYPE_CAST(v) (v)
	#define SAMPLE_TYPE_CAST(v) (v)
#else
	/* packed input: one uint holding two signed normalized 16 bit values */
	#define DATA_TYPE           uint
	#define SAMPLE_TYPE_CAST(v) unpackSnorm2x16(v)
	#if OutputFloats
		/* store floats rescaled to the 16 bit integer range */
		#define OUT_DATA_TYPE       vec2
		#define RESULT_TYPE_CAST(v) (clamp((v), -1.0, 1.0) * 32767.0f)
	#else
		#define RESULT_TYPE_CAST(v) packSnorm2x16(v)
	#endif
#endif

/* Output element type defaults to the input element type. */
#ifndef OUT_DATA_TYPE
	#define OUT_DATA_TYPE DATA_TYPE
#endif

#if ComplexFilter
	/* complex coefficients: each tap is a complex multiplication */
	#define FILTER_TYPE        vec2
	#define apply_filter(iq, h) complex_mul((iq), (h))
#else
	/* real coefficients: each tap scales both components equally */
	#define FILTER_TYPE        float
	#define apply_filter(iq, h) ((iq) * (h))
#endif

layout(std430, binding = 1) readonly restrict buffer buffer_1 {
	DATA_TYPE in_data[];
};

layout(std430, binding = 2) writeonly restrict buffer buffer_2 {
	OUT_DATA_TYPE out_data[];
};

layout(std430, binding = 3) readonly restrict buffer buffer_3 {
	FILTER_TYPE filter_coefficients[FilterLength];
};

/* Complex product a * b with .x as the real part and .y as the imaginary part:
 * (a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x). */
vec2 complex_mul(vec2 a, vec2 b)
{
	/* mat2 is column major: the columns are (b.x, b.y) and (-b.y, b.x) */
	mat2 m      = mat2(b.x, b.y, -b.y, b.x);
	vec2 result = m * a;
	return result;
}

#if Demodulate
/* Multiplies iq by exp(-i * 2 * pi * DemodulationFrequency * index / SamplingFrequency).
 * For a continuous result `index` must advance by one per input sample across the
 * whole channel, not just within one workgroup. */
vec2 rotate_iq(vec2 iq, uint index)
{
	float arg    = radians(360) * DemodulationFrequency * index / SamplingFrequency;
	vec2  result = complex_mul(iq, vec2(cos(arg), -sin(arg)));
	return result;
}
#endif

/* Loads element `index` of the input buffer and converts it to SAMPLE_TYPE. */
SAMPLE_TYPE sample_rf(uint index)
{
	SAMPLE_TYPE result = SAMPLE_TYPE_CAST(in_data[index]);
	return result;
}

/* Workgroup cache of input samples: DecimationRate samples per invocation along x
 * plus the FilterLength - 1 samples of history needed by the first output. */
shared SAMPLE_TYPE rf[DecimationRate * gl_WorkGroupSize.x + FilterLength - 1];

/* One invocation produces one decimated output sample for (channel, transmit)
 * taken from the y and z components of the global invocation id. */
void main()
{
	uint out_sample = gl_GlobalInvocationID.x;
	uint channel    = gl_GlobalInvocationID.y;
	uint transmit   = gl_GlobalInvocationID.z;

	uint in_offset  = InputChannelStride  * channel + InputTransmitStride * transmit;
	uint out_offset = OutputChannelStride * channel +
	                  OutputTransmitStride * transmit +
	                  OutputSampleStride   * out_sample;

	uint thread_index = gl_LocalInvocationIndex;
	uint thread_count = gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z;
	/////////////////////////
	// NOTE: sample caching
	{
		/* cache slot `index` holds the input sample at position
		 * window_start + index - (FilterLength - 1) within this channel */
		uint window_start = DecimationRate * gl_WorkGroupID.x * gl_WorkGroupSize.x;
		in_offset += window_start - (FilterLength - 1);

		/* spread the cache fill over all invocations; the first leftover_count
		 * invocations each load one extra sample */
		uint total_samples       = rf.length();
		uint samples_per_thread  = total_samples / thread_count;
		uint leftover_count      = total_samples % thread_count;
		uint samples_this_thread = samples_per_thread + uint(thread_index < leftover_count);

		/* gain applied to demodulated samples: 1 with a complex filter, sqrt(2) otherwise */
		const float scale = bool(ComplexFilter) ? 1 : sqrt(2);
		for (uint i = 0; i < samples_this_thread; i++) {
			uint index = thread_count * i + thread_index;

			/* Position of this sample within the channel. Positions before the
			 * channel start wrap around to huge unsigned values, so the single
			 * comparison below rejects samples before the start as well as
			 * samples past the end. Rejected slots are zero filled instead of
			 * being read from outside this channel's data. Slots past the end
			 * never contribute to an output that passes the SampleCount check
			 * at the bottom of main(). */
			uint sample_index = window_start + index - (FilterLength - 1);
			if (sample_index >= uint(SampleCount)) {
				rf[index] = SAMPLE_TYPE(0);
			} else {
				#if Demodulate
					/* The phase index must be continuous across workgroups, so it
					 * is derived from the position in the channel rather than the
					 * cache slot; it equals `index` in workgroup 0. The vec2(1, -1)
					 * factor negates the second component of the sample before
					 * mixing. */
					rf[index] = scale * rotate_iq(sample_rf(in_offset + index) * vec2(1, -1),
					                              window_start + index);
				#else
					rf[index] = sample_rf(in_offset + index);
				#endif
			}
		}
	}
	/* every cache slot must be written before any invocation starts filtering */
	barrier();

	if (out_sample < SampleCount / DecimationRate) {
		SAMPLE_TYPE result = SAMPLE_TYPE(0);
		/* first cache slot of the FilterLength wide window for this output */
		uint offset = DecimationRate * thread_index;
		for (uint j = 0; j < FilterLength; j++)
			result += apply_filter(rf[offset + j], filter_coefficients[j]);
		out_data[out_offset] = RESULT_TYPE_CAST(result);
	}
}