From fd10688a4d5b9b4710e19121ac16cb0e3675eccf Mon Sep 17 00:00:00 2001 From: Donovan Hutchence Date: Mon, 9 Dec 2024 11:44:46 +0000 Subject: [PATCH] Further gsplat shader refinements (#7185) --- .../examples/loaders/gsplat-many.shader.vert | 28 ++-- src/scene/gsplat/gsplat-compressed.js | 5 +- src/scene/gsplat/gsplat-instance.js | 5 +- src/scene/gsplat/gsplat.js | 5 +- src/scene/shader-lib/chunks/chunks.js | 10 +- .../shader-lib/chunks/gsplat/vert/gsplat.js | 27 ++-- .../chunks/gsplat/vert/gsplatCenter.js | 27 ++++ .../chunks/gsplat/vert/gsplatColor.js | 4 +- .../chunks/gsplat/vert/gsplatCommon.js | 140 +++--------------- .../gsplat/vert/gsplatCompressedData.js | 18 +-- .../chunks/gsplat/vert/gsplatCompressedSH.js | 8 +- .../chunks/gsplat/vert/gsplatCorner.js | 63 ++++++++ .../chunks/gsplat/vert/gsplatData.js | 13 +- .../chunks/gsplat/{ => vert}/gsplatOutput.js | 0 .../shader-lib/chunks/gsplat/vert/gsplatSH.js | 22 +-- .../chunks/gsplat/vert/gsplatSource.js | 33 +++++ 16 files changed, 222 insertions(+), 186 deletions(-) create mode 100644 src/scene/shader-lib/chunks/gsplat/vert/gsplatCenter.js create mode 100644 src/scene/shader-lib/chunks/gsplat/vert/gsplatCorner.js rename src/scene/shader-lib/chunks/gsplat/{ => vert}/gsplatOutput.js (100%) create mode 100644 src/scene/shader-lib/chunks/gsplat/vert/gsplatSource.js diff --git a/examples/src/examples/loaders/gsplat-many.shader.vert b/examples/src/examples/loaders/gsplat-many.shader.vert index 22f66033d15..4ff48299619 100644 --- a/examples/src/examples/loaders/gsplat-many.shader.vert +++ b/examples/src/examples/loaders/gsplat-many.shader.vert @@ -40,37 +40,41 @@ vec4 animateColor(float height, vec4 clr) { void main(void) { // read gaussian center - SplatState state; - if (!initState(state)) { + SplatSource source; + if (!initSource(source)) { gl_Position = discardVec; return; } - vec3 center = animatePosition(readCenter(state)); + vec3 centerPos = animatePosition(readCenter(source)); + + SplatCenter center; + initCenter(source, centerPos, center); // project center to screen space - ProjectedState projState; - if (!projectCenter(state, center, projState)) { + SplatCorner corner; + if (!initCorner(source, center, corner)) { gl_Position = discardVec; return; } // read color - vec4 clr = readColor(state); + vec4 clr = readColor(source); // evaluate spherical harmonics #if SH_BANDS > 0 - clr.xyz = max(clr.xyz + evalSH(state, projState), 0.0); + vec3 dir = normalize(center.view * mat3(center.modelView)); + clr.xyz += evalSH(state, dir); #endif - clr = animateColor(center.y, clr); + clr = animateColor(centerPos.y, clr); - applyClipping(projState, clr.w); + clipCorner(corner, clr.w); // write output - gl_Position = projState.cornerProj; - gaussianUV = projState.cornerUV; - gaussianColor = vec4(prepareOutputFromGamma(clr.xyz), clr.w); + gl_Position = center.proj + vec4(corner.offset, 0.0, 0.0); + gaussianUV = corner.uv; + gaussianColor = vec4(prepareOutputFromGamma(max(clr.xyz, 0.0)), clr.w); #ifndef DITHER_NONE id = float(state.id); diff --git a/src/scene/gsplat/gsplat-compressed.js b/src/scene/gsplat/gsplat-compressed.js index 66d398274f9..402f88a0035 100644 --- a/src/scene/gsplat/gsplat-compressed.js +++ b/src/scene/gsplat/gsplat-compressed.js @@ -27,6 +27,8 @@ class GSplatCompressed { numSplats; + numSplatsVisible; + /** @type {BoundingBox} */ aabb; @@ -57,6 +59,7 @@ class GSplatCompressed { this.device = device; this.numSplats = numSplats; + this.numVisibleSplats = numSplats; // initialize aabb this.aabb = new BoundingBox(); @@ -147,7 +150,7 @@ class GSplatCompressed { result.setDefine('GSPLAT_COMPRESSED_DATA', true); result.setParameter('packedTexture', this.packedTexture); result.setParameter('chunkTexture', this.chunkTexture); - result.setParameter('tex_params', new Float32Array([this.numSplats, this.packedTexture.width, this.chunkTexture.width / 5])); + result.setParameter('numSplats', this.numSplatsVisible); if (this.shTexture0) { result.setDefine('SH_BANDS', 3); result.setParameter('shTexture0', this.shTexture0); diff --git a/src/scene/gsplat/gsplat-instance.js b/src/scene/gsplat/gsplat-instance.js index 0291daf2eeb..fcaa67b4078 100644 --- a/src/scene/gsplat/gsplat-instance.js +++ b/src/scene/gsplat/gsplat-instance.js @@ -143,10 +143,7 @@ class GSplatInstance { this.meshInstance.instancingCount = Math.ceil(count / splatInstanceSize); // update splat count on the material - const tex_params = this.material.getParameter('tex_params'); - if (tex_params?.data) { - tex_params.data[0] = count; - } + this.material.setParameter('numSplats', count); }); } } diff --git a/src/scene/gsplat/gsplat.js b/src/scene/gsplat/gsplat.js index e3140eb2072..d9673f276c1 100644 --- a/src/scene/gsplat/gsplat.js +++ b/src/scene/gsplat/gsplat.js @@ -31,6 +31,8 @@ class GSplat { numSplats; + numSplatsVisible; + /** @type {Float32Array} */ centers; @@ -70,6 +72,7 @@ class GSplat { this.device = device; this.numSplats = numSplats; + this.numSplatsVisible = numSplats; this.centers = new Float32Array(gsplatData.numSplats * 3); gsplatData.getCenters(this.centers); @@ -116,7 +119,7 @@ class GSplat { result.setParameter('splatColor', this.colorTexture); result.setParameter('transformA', this.transformATexture); result.setParameter('transformB', this.transformBTexture); - result.setParameter('tex_params', new Float32Array([this.numSplats, this.colorTexture.width])); + result.setParameter('numSplats', this.numSplatsVisible); if (this.hasSH) { result.setDefine('SH_BANDS', 3); result.setParameter('splatSH_1to3', this.sh1to3Texture); diff --git a/src/scene/shader-lib/chunks/chunks.js b/src/scene/shader-lib/chunks/chunks.js index e72977780ac..3c5756634dd 100644 --- a/src/scene/shader-lib/chunks/chunks.js +++ b/src/scene/shader-lib/chunks/chunks.js @@ -56,14 +56,17 @@ import gamma2_2PS from './common/frag/gamma2_2.js'; import gles3PS from '../../../platform/graphics/shader-chunks/frag/gles3.js'; import gles3VS from '../../../platform/graphics/shader-chunks/vert/gles3.js'; import glossPS from './standard/frag/gloss.js'; +import gsplatCenterVS from './gsplat/vert/gsplatCenter.js'; import gsplatColorVS from './gsplat/vert/gsplatColor.js'; import gsplatCommonVS from './gsplat/vert/gsplatCommon.js'; import gsplatCompressedDataVS from './gsplat/vert/gsplatCompressedData.js'; import gsplatCompressedSHVS from './gsplat/vert/gsplatCompressedSH.js'; +import gsplatCornerVS from './gsplat/vert/gsplatCorner.js'; import gsplatDataVS from './gsplat/vert/gsplatData.js'; -import gsplatOutputPS from './gsplat/gsplatOutput.js'; +import gsplatOutputVS from './gsplat/vert/gsplatOutput.js'; import gsplatPS from './gsplat/frag/gsplat.js'; import gsplatSHVS from './gsplat/vert/gsplatSH.js'; +import gsplatSourceVS from './gsplat/vert/gsplatSource.js'; import gsplatVS from './gsplat/vert/gsplat.js'; import iridescenceDiffractionPS from './lit/frag/iridescenceDiffraction.js'; import iridescencePS from './standard/frag/iridescence.js'; @@ -266,14 +269,17 @@ const shaderChunks = { gles3PS, gles3VS, glossPS, + gsplatCenterVS, + gsplatCornerVS, gsplatColorVS, gsplatCommonVS, gsplatCompressedDataVS, gsplatCompressedSHVS, gsplatDataVS, - gsplatOutputPS, + gsplatOutputVS, gsplatPS, gsplatSHVS, + gsplatSourceVS, gsplatVS, iridescenceDiffractionPS, iridescencePS, diff --git a/src/scene/shader-lib/chunks/gsplat/vert/gsplat.js b/src/scene/shader-lib/chunks/gsplat/vert/gsplat.js index c17ebabac92..02e20293135 100644 --- a/src/scene/shader-lib/chunks/gsplat/vert/gsplat.js +++ b/src/scene/shader-lib/chunks/gsplat/vert/gsplat.js @@ -12,38 +12,43 @@ mediump vec4 discardVec = vec4(0.0, 0.0, 2.0, 1.0); void main(void) { // read gaussian details - SplatState state; - if (!initState(state)) { + SplatSource source; + if (!initSource(source)) { gl_Position = discardVec; return; } - vec3 center = readCenter(state); + vec3 modelCenter = readCenter(source); + + SplatCenter center; + initCenter(source, modelCenter, center); // project center to screen space - ProjectedState projState; - if (!projectCenter(state, center, projState)) { + SplatCorner corner; + if (!initCorner(source, center, corner)) { gl_Position = discardVec; return; } // read color - vec4 clr = readColor(state); + vec4 clr = readColor(source); // evaluate spherical harmonics #if SH_BANDS > 0 - clr.xyz += evalSH(state, projState); + // calculate the model-space view direction + vec3 dir = normalize(center.view * mat3(center.modelView)); + clr.xyz += evalSH(source, dir); #endif - applyClipping(projState, clr.w); + clipCorner(corner, clr.w); // write output - gl_Position = projState.cornerProj; - gaussianUV = projState.cornerUV; + gl_Position = center.proj + vec4(corner.offset, 0, 0); + gaussianUV = corner.uv; gaussianColor = vec4(prepareOutputFromGamma(max(clr.xyz, 0.0)), clr.w); #ifndef DITHER_NONE - id = float(state.id); + id = float(source.id); #endif } `; diff --git a/src/scene/shader-lib/chunks/gsplat/vert/gsplatCenter.js b/src/scene/shader-lib/chunks/gsplat/vert/gsplatCenter.js new file mode 100644 index 00000000000..bf20c944e7d --- /dev/null +++ b/src/scene/shader-lib/chunks/gsplat/vert/gsplatCenter.js @@ -0,0 +1,27 @@ +export default /* glsl */` +uniform mat4 matrix_model; +uniform mat4 matrix_view; +uniform mat4 matrix_projection; + +// project the model space gaussian center to view and clip space +bool initCenter(SplatSource source, vec3 modelCenter, out SplatCenter center) { + mat4 modelView = matrix_view * matrix_model; + vec4 centerView = modelView * vec4(modelCenter, 1.0); + + // early out if splat is behind the camear + if (centerView.z > 0.0) { + return false; + } + + vec4 centerProj = matrix_projection * centerView; + + // ensure gaussians are not clipped by camera near and far + centerProj.z = clamp(centerProj.z, -abs(centerProj.w), abs(centerProj.w)); + + center.view = centerView.xyz / centerView.w; + center.proj = centerProj; + center.projMat00 = matrix_projection[0][0]; + center.modelView = modelView; + return true; +} +`; diff --git a/src/scene/shader-lib/chunks/gsplat/vert/gsplatColor.js b/src/scene/shader-lib/chunks/gsplat/vert/gsplatColor.js index 942780bbe49..869b3016c19 100644 --- a/src/scene/shader-lib/chunks/gsplat/vert/gsplatColor.js +++ b/src/scene/shader-lib/chunks/gsplat/vert/gsplatColor.js @@ -2,8 +2,8 @@ export default /* glsl */` uniform mediump sampler2D splatColor; -vec4 readColor(in SplatState state) { - return texelFetch(splatColor, state.uv, 0); +vec4 readColor(in SplatSource source) { + return texelFetch(splatColor, source.uv, 0); } `; diff --git a/src/scene/shader-lib/chunks/gsplat/vert/gsplatCommon.js b/src/scene/shader-lib/chunks/gsplat/vert/gsplatCommon.js index 2647ecb39d2..a6c13bb1f0a 100644 --- a/src/scene/shader-lib/chunks/gsplat/vert/gsplatCommon.js +++ b/src/scene/shader-lib/chunks/gsplat/vert/gsplatCommon.js @@ -1,20 +1,25 @@ export default /* glsl */` -struct SplatState { +// stores the source UV and order of the splat +struct SplatSource { uint order; // render order uint id; // splat id ivec2 uv; // splat uv vec2 cornerUV; // corner coordinates for this vertex of the gaussian (-1, -1)..(1, 1) }; -struct ProjectedState { +// stores the camera and clip space position of the gaussian center +struct SplatCenter { + vec3 view; // center in view space + vec4 proj; // center in clip space mat4 modelView; // model-view matrix - vec3 centerCam; // center in camera space - vec4 centerProj; // center in clip space + float projMat00; // elememt [0][0] of the projection matrix +}; - vec2 cornerOffset; // corner offset in clip space - vec4 cornerProj; // corner position in clip space - vec2 cornerUV; // corner uv +// stores the offset from center for the current gaussian +struct SplatCorner { + vec2 offset; // corner offset from center in clip space + vec2 uv; // corner uv }; #if SH_BANDS > 0 @@ -36,115 +41,17 @@ struct ProjectedState { #include "gsplatSHVS" #endif -#include "gsplatOutputPS" - -uniform mat4 matrix_model; -uniform mat4 matrix_view; -uniform mat4 matrix_projection; -uniform vec2 viewport; // viewport dimensions -uniform vec4 camera_params; // 1 / far, far, near, isOrtho - -// initialize the splat state structure -bool initState(out SplatState state) { - // calculate splat order - state.order = vertex_id_attrib + uint(vertex_position.z); - - // return if out of range (since the last block of splats may be partially full) - if (state.order >= tex_params.x) { - return false; - } - - ivec2 orderUV = ivec2(state.order % tex_params.y, state.order / tex_params.y); - - // read splat id - state.id = texelFetch(splatOrder, orderUV, 0).r; - - // map id to uv - state.uv = ivec2(state.id % tex_params.y, state.id / tex_params.y); +#include "gsplatSourceVS" +#include "gsplatCenterVS" +#include "gsplatCornerVS" +#include "gsplatOutputVS" - // get the corner - state.cornerUV = vertex_position.xy; - - return true; -} - -// calculate 2d covariance vectors -bool projectCenter(SplatState state, vec3 center, out ProjectedState projState) { - // project center to screen space - mat4 model_view = matrix_view * matrix_model; - vec4 centerCam = model_view * vec4(center, 1.0); - vec4 centerProj = matrix_projection * centerCam; - if (centerProj.z < -centerProj.w) { - return false; - } - - // get covariance - vec3 covA, covB; - readCovariance(state, covA, covB); - - mat3 Vrk = mat3( - covA.x, covA.y, covA.z, - covA.y, covB.x, covB.y, - covA.z, covB.y, covB.z - ); - - float focal = viewport.x * matrix_projection[0][0]; - - vec3 v = camera_params.w == 1.0 ? vec3(0.0, 0.0, 1.0) : centerCam.xyz; - float J1 = focal / v.z; - vec2 J2 = -J1 / v.z * v.xy; - mat3 J = mat3( - J1, 0.0, J2.x, - 0.0, J1, J2.y, - 0.0, 0.0, 0.0 - ); - - mat3 W = transpose(mat3(model_view)); - mat3 T = W * J; - mat3 cov = transpose(T) * Vrk * T; - - float diagonal1 = cov[0][0] + 0.3; - float offDiagonal = cov[0][1]; - float diagonal2 = cov[1][1] + 0.3; - - float mid = 0.5 * (diagonal1 + diagonal2); - float radius = length(vec2((diagonal1 - diagonal2) / 2.0, offDiagonal)); - float lambda1 = mid + radius; - float lambda2 = max(mid - radius, 0.1); - - float l1 = 2.0 * min(sqrt(2.0 * lambda1), 1024.0); - float l2 = 2.0 * min(sqrt(2.0 * lambda2), 1024.0); - - // early-out gaussians smaller than 2 pixels - if (l1 < 2.0 && l2 < 2.0) { - return false; - } - - // perform clipping test against x/y - if (any(greaterThan(abs(centerProj.xy) - vec2(l1, l2) / viewport * centerProj.w, centerProj.ww))) { - return false; - } - - vec2 diagonalVector = normalize(vec2(offDiagonal, lambda1 - diagonal1)); - vec2 v1 = l1 * diagonalVector; - vec2 v2 = l2 * vec2(diagonalVector.y, -diagonalVector.x); - - projState.modelView = model_view; - projState.centerCam = centerCam.xyz; - projState.centerProj = centerProj; - projState.cornerOffset = (state.cornerUV.x * v1 + state.cornerUV.y * v2) / viewport * centerProj.w; - projState.cornerProj = centerProj + vec4(projState.cornerOffset, 0.0, 0.0); - projState.cornerUV = state.cornerUV; - - return true; -} - -// modify the projected gaussian so it excludes regions with alpha +// modify the gaussian corner so it excludes gaussian regions with alpha // less than 1/255 -void applyClipping(inout ProjectedState projState, float alpha) { +void clipCorner(inout SplatCorner corner, float alpha) { float clip = min(1.0, sqrt(-log(1.0 / 255.0 / alpha)) / 2.0); - projState.cornerProj.xy -= projState.cornerOffset * (1.0 - clip); - projState.cornerUV *= clip; + corner.offset *= clip; + corner.uv *= clip; } // spherical Harmonics @@ -172,15 +79,12 @@ void applyClipping(inout ProjectedState projState, float alpha) { #endif // see https://github.com/graphdeco-inria/gaussian-splatting/blob/main/utils/sh_utils.py -vec3 evalSH(in SplatState state, in ProjectedState projState) { +vec3 evalSH(in SplatSource source, in vec3 dir) { // read sh coefficients vec3 sh[SH_COEFFS]; float scale; - readSHData(state, sh, scale); - - // calculate the model-space view direction - vec3 dir = normalize(projState.centerCam * mat3(projState.modelView)); + readSHData(source, sh, scale); float x = dir.x; float y = dir.y; diff --git a/src/scene/shader-lib/chunks/gsplat/vert/gsplatCompressedData.js b/src/scene/shader-lib/chunks/gsplat/vert/gsplatCompressedData.js index e9af68b704d..e3149bf5269 100644 --- a/src/scene/shader-lib/chunks/gsplat/vert/gsplatCompressedData.js +++ b/src/scene/shader-lib/chunks/gsplat/vert/gsplatCompressedData.js @@ -1,9 +1,4 @@ export default /* glsl */` -attribute vec3 vertex_position; // xy: cornerUV, z: render order offset -attribute uint vertex_id_attrib; // render order base - -uniform uvec3 tex_params; // num splats, packed width, chunked width -uniform highp usampler2D splatOrder; uniform highp usampler2D packedTexture; uniform highp sampler2D chunkTexture; @@ -66,9 +61,10 @@ mat3 quatToMat3(vec4 R) { } // read center -vec3 readCenter(SplatState state) { - uint chunkId = state.id / 256u; - ivec2 chunkUV = ivec2((chunkId % tex_params.z) * 5u, chunkId / tex_params.z); +vec3 readCenter(SplatSource source) { + uint w = uint(textureSize(chunkTexture, 0).x) / 5; + uint chunkId = source.id / 256u; + ivec2 chunkUV = ivec2((chunkId % w) * 5u, chunkId / w); // read chunk and packed compressed data chunkDataA = texelFetch(chunkTexture, chunkUV, 0); @@ -76,12 +72,12 @@ vec3 readCenter(SplatState state) { chunkDataC = texelFetch(chunkTexture, chunkUV + ivec2(2, 0), 0); chunkDataD = texelFetch(chunkTexture, chunkUV + ivec2(3, 0), 0); chunkDataE = texelFetch(chunkTexture, chunkUV + ivec2(4, 0), 0); - packedData = texelFetch(packedTexture, state.uv, 0); + packedData = texelFetch(packedTexture, source.uv, 0); return mix(chunkDataA.xyz, vec3(chunkDataA.w, chunkDataB.xy), unpack111011(packedData.x)); } -vec4 readColor(in SplatState state) { +vec4 readColor(in SplatSource source) { vec4 r = unpack8888(packedData.w); return vec4(mix(chunkDataD.xyz, vec3(chunkDataD.w, chunkDataE.xy), r.rgb), r.w); } @@ -95,7 +91,7 @@ vec3 getScale() { } // given a rotation matrix and scale vector, compute 3d covariance A and B -void readCovariance(in SplatState state, out vec3 covA, out vec3 covB) { +void readCovariance(in SplatSource source, out vec3 covA, out vec3 covB) { mat3 rot = quatToMat3(getRotation()); vec3 scale = getScale(); diff --git a/src/scene/shader-lib/chunks/gsplat/vert/gsplatCompressedSH.js b/src/scene/shader-lib/chunks/gsplat/vert/gsplatCompressedSH.js index a63b30ba443..00960780376 100644 --- a/src/scene/shader-lib/chunks/gsplat/vert/gsplatCompressedSH.js +++ b/src/scene/shader-lib/chunks/gsplat/vert/gsplatCompressedSH.js @@ -9,11 +9,11 @@ vec4 unpack8888s(in uint bits) { return vec4((uvec4(bits) >> uvec4(0u, 8u, 16u, 24u)) & 0xffu) * (8.0 / 255.0) - 4.0; } -void readSHData(in SplatState state, out vec3 sh[15], out float scale) { +void readSHData(in SplatSource source, out vec3 sh[15], out float scale) { // read the sh coefficients - uvec4 shData0 = texelFetch(shTexture0, state.uv, 0); - uvec4 shData1 = texelFetch(shTexture1, state.uv, 0); - uvec4 shData2 = texelFetch(shTexture2, state.uv, 0); + uvec4 shData0 = texelFetch(shTexture0, source.uv, 0); + uvec4 shData1 = texelFetch(shTexture1, source.uv, 0); + uvec4 shData2 = texelFetch(shTexture2, source.uv, 0); vec4 r0 = unpack8888s(shData0.x); vec4 r1 = unpack8888s(shData0.y); diff --git a/src/scene/shader-lib/chunks/gsplat/vert/gsplatCorner.js b/src/scene/shader-lib/chunks/gsplat/vert/gsplatCorner.js new file mode 100644 index 00000000000..c2f98d64545 --- /dev/null +++ b/src/scene/shader-lib/chunks/gsplat/vert/gsplatCorner.js @@ -0,0 +1,63 @@ +export default /* glsl */` +uniform vec2 viewport; // viewport dimensions +uniform vec4 camera_params; // 1 / far, far, near, isOrtho + +// calculate the clip-space offset from the center for this gaussian +bool initCorner(SplatSource source, SplatCenter center, out SplatCorner corner) { + // get covariance + vec3 covA, covB; + readCovariance(source, covA, covB); + + mat3 Vrk = mat3( + covA.x, covA.y, covA.z, + covA.y, covB.x, covB.y, + covA.z, covB.y, covB.z + ); + + float focal = viewport.x * center.projMat00; + + vec3 v = camera_params.w == 1.0 ? vec3(0.0, 0.0, 1.0) : center.view.xyz; + float J1 = focal / v.z; + vec2 J2 = -J1 / v.z * v.xy; + mat3 J = mat3( + J1, 0.0, J2.x, + 0.0, J1, J2.y, + 0.0, 0.0, 0.0 + ); + + mat3 W = transpose(mat3(center.modelView)); + mat3 T = W * J; + mat3 cov = transpose(T) * Vrk * T; + + float diagonal1 = cov[0][0] + 0.3; + float offDiagonal = cov[0][1]; + float diagonal2 = cov[1][1] + 0.3; + + float mid = 0.5 * (diagonal1 + diagonal2); + float radius = length(vec2((diagonal1 - diagonal2) / 2.0, offDiagonal)); + float lambda1 = mid + radius; + float lambda2 = max(mid - radius, 0.1); + + float l1 = 2.0 * min(sqrt(2.0 * lambda1), 1024.0); + float l2 = 2.0 * min(sqrt(2.0 * lambda2), 1024.0); + + // early-out gaussians smaller than 2 pixels + if (l1 < 2.0 && l2 < 2.0) { + return false; + } + + // perform cull against x/y axes + if (any(greaterThan(abs(center.proj.xy) - vec2(l1, l2) / viewport * center.proj.w, center.proj.ww))) { + return false; + } + + vec2 diagonalVector = normalize(vec2(offDiagonal, lambda1 - diagonal1)); + vec2 v1 = l1 * diagonalVector; + vec2 v2 = l2 * vec2(diagonalVector.y, -diagonalVector.x); + + corner.offset = (source.cornerUV.x * v1 + source.cornerUV.y * v2) / viewport * center.proj.w; + corner.uv = source.cornerUV; + + return true; +} +`; diff --git a/src/scene/shader-lib/chunks/gsplat/vert/gsplatData.js b/src/scene/shader-lib/chunks/gsplat/vert/gsplatData.js index d1c530382dc..20ada43f768 100644 --- a/src/scene/shader-lib/chunks/gsplat/vert/gsplatData.js +++ b/src/scene/shader-lib/chunks/gsplat/vert/gsplatData.js @@ -1,9 +1,4 @@ export default /* glsl */` -attribute vec3 vertex_position; // xy: cornerUV, z: render order offset -attribute uint vertex_id_attrib; // render order base - -uniform uvec2 tex_params; // num splats, texture width -uniform highp usampler2D splatOrder; uniform highp usampler2D transformA; uniform highp sampler2D transformB; @@ -11,16 +6,16 @@ uniform highp sampler2D transformB; uint tAw; // read the model-space center of the gaussian -vec3 readCenter(SplatState state) { +vec3 readCenter(SplatSource source) { // read transform data - uvec4 tA = texelFetch(transformA, state.uv, 0); + uvec4 tA = texelFetch(transformA, source.uv, 0); tAw = tA.w; return uintBitsToFloat(tA.xyz); } // sample covariance vectors -void readCovariance(in SplatState state, out vec3 covA, out vec3 covB) { - vec4 tB = texelFetch(transformB, state.uv, 0); +void readCovariance(in SplatSource source, out vec3 covA, out vec3 covB) { + vec4 tB = texelFetch(transformB, source.uv, 0); vec2 tC = unpackHalf2x16(tAw); covA = tB.xyz; covB = vec3(tC.x, tC.y, tB.w); diff --git a/src/scene/shader-lib/chunks/gsplat/gsplatOutput.js b/src/scene/shader-lib/chunks/gsplat/vert/gsplatOutput.js similarity index 100% rename from src/scene/shader-lib/chunks/gsplat/gsplatOutput.js rename to src/scene/shader-lib/chunks/gsplat/vert/gsplatOutput.js diff --git a/src/scene/shader-lib/chunks/gsplat/vert/gsplatSH.js b/src/scene/shader-lib/chunks/gsplat/vert/gsplatSH.js index cb81bda88a5..a0be892e5be 100644 --- a/src/scene/shader-lib/chunks/gsplat/vert/gsplatSH.js +++ b/src/scene/shader-lib/chunks/gsplat/vert/gsplatSH.js @@ -29,28 +29,28 @@ void fetch(in uint t, out vec3 a) { #if SH_BANDS == 1 uniform highp usampler2D splatSH_1to3; - void readSHData(in SplatState state, out vec3 sh[3], out float scale) { - fetchScale(texelFetch(splatSH_1to3, state.uv, 0), scale, sh[0], sh[1], sh[2]); + void readSHData(in SplatSource source, out vec3 sh[3], out float scale) { + fetchScale(texelFetch(splatSH_1to3, source.uv, 0), scale, sh[0], sh[1], sh[2]); } #elif SH_BANDS == 2 uniform highp usampler2D splatSH_1to3; uniform highp usampler2D splatSH_4to7; uniform highp usampler2D splatSH_8to11; - void readSHData(in SplatState state, out vec3 sh[8], out float scale) { - fetchScale(texelFetch(splatSH_1to3, state.uv, 0), scale, sh[0], sh[1], sh[2]); - fetch(texelFetch(splatSH_4to7, state.uv, 0), sh[3], sh[4], sh[5], sh[6]); - fetch(texelFetch(splatSH_8to11, state.uv, 0).x, sh[7]); + void readSHData(in SplatSource source, out vec3 sh[8], out float scale) { + fetchScale(texelFetch(splatSH_1to3, source.uv, 0), scale, sh[0], sh[1], sh[2]); + fetch(texelFetch(splatSH_4to7, source.uv, 0), sh[3], sh[4], sh[5], sh[6]); + fetch(texelFetch(splatSH_8to11, source.uv, 0).x, sh[7]); } #else uniform highp usampler2D splatSH_1to3; uniform highp usampler2D splatSH_4to7; uniform highp usampler2D splatSH_8to11; uniform highp usampler2D splatSH_12to15; - void readSHData(in SplatState state, out vec3 sh[15], out float scale) { - fetchScale(texelFetch(splatSH_1to3, state.uv, 0), scale, sh[0], sh[1], sh[2]); - fetch(texelFetch(splatSH_4to7, state.uv, 0), sh[3], sh[4], sh[5], sh[6]); - fetch(texelFetch(splatSH_8to11, state.uv, 0), sh[7], sh[8], sh[9], sh[10]); - fetch(texelFetch(splatSH_12to15, state.uv, 0), sh[11], sh[12], sh[13], sh[14]); + void readSHData(in SplatSource source, out vec3 sh[15], out float scale) { + fetchScale(texelFetch(splatSH_1to3, source.uv, 0), scale, sh[0], sh[1], sh[2]); + fetch(texelFetch(splatSH_4to7, source.uv, 0), sh[3], sh[4], sh[5], sh[6]); + fetch(texelFetch(splatSH_8to11, source.uv, 0), sh[7], sh[8], sh[9], sh[10]); + fetch(texelFetch(splatSH_12to15, source.uv, 0), sh[11], sh[12], sh[13], sh[14]); } #endif diff --git a/src/scene/shader-lib/chunks/gsplat/vert/gsplatSource.js b/src/scene/shader-lib/chunks/gsplat/vert/gsplatSource.js new file mode 100644 index 00000000000..14bfc5b79e0 --- /dev/null +++ b/src/scene/shader-lib/chunks/gsplat/vert/gsplatSource.js @@ -0,0 +1,33 @@ +export default /* glsl */` +attribute vec3 vertex_position; // xy: cornerUV, z: render order offset +attribute uint vertex_id_attrib; // render order base + +uniform uint numSplats; // total number of splats +uniform highp usampler2D splatOrder; // per-splat index to source gaussian + +// initialize the splat source structure +bool initSource(out SplatSource source) { + uint w = uint(textureSize(splatOrder, 0).x); + + // calculate splat order + source.order = vertex_id_attrib + uint(vertex_position.z); + + // return if out of range (since the last block of splats may be partially full) + if (source.order >= numSplats) { + return false; + } + + ivec2 orderUV = ivec2(source.order % w, source.order / w); + + // read splat id + source.id = texelFetch(splatOrder, orderUV, 0).r; + + // map id to uv + source.uv = ivec2(source.id % w, source.id / w); + + // get the corner + source.cornerUV = vertex_position.xy; + + return true; +} +`;