
Question of technical implementation details on Z^max (Equation 3) #60

Open

Justin62628 opened this issue Aug 27, 2023 · 0 comments

@Justin62628
Hi Simon,

I'm trying to reproduce your recent paper on splatting-based synthesis for video frame interpolation, and it is really nice work that has inspired me a lot. However, I'm stuck implementing the numerically stable softsplat you mention in Section 3, where you say to "warp Z0 to time t as Zmax ... this step is not and need not be differentiable ...". For context, my understanding of why subtracting Z^max helps is the standard log-sum-exp trick, illustrated below.
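
# a quick illustration (my own, not from the paper): subtracting the
# per-pixel maximum keeps exp() from overflowing, and because every
# contribution to a given target pixel is scaled by the same factor
# exp(-Z^max), the softmax normalization cancels the constant out
import torch

z = torch.tensor([90.0, 100.0])   # plausible magnitudes for an unbounded metric
print(torch.exp(z))               # tensor([inf, inf]) -- overflows float32
print(torch.exp(z - z.max()))     # tensor([4.5400e-05, 1.0000e+00]) -- safe

I'd appreciate it if you could further clarify the following two questions: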

  1. How to implement the necessary "backward" function of torch.autograd.Function when computing Z^max during training? I've implemented the following snippet (using the cuda_kernel / cuda_launch helpers from softsplat.py) to calculate Z^max in the forward pass, and it works well:
class softsplat_zmax_func(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(self, tenIn, tenFlow):
        tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]])  # max weight
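        # note (added): zero-initialization floors the max at 0; if the metric can be negative, consider filling with a large negative value instead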

        if tenIn.is_cuda == True:
            cuda_launch(cuda_kernel('zmax_out', '''
            
                __device__ __forceinline__ float atomicMinFloat(float* addr, float value) {
                    float old;
                    old = !signbit(value) ? __int_as_float(atomicMin((int*)addr, __float_as_int(value))) :
                        __uint_as_float(atomicMax((unsigned int*)addr, __float_as_uint(value)));
                
                    return old;
                }
                
                __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
                    float old;
                    old = !signbit(value) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) :
                        __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
                
                    return old;
                }
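
                // note (added): CUDA's native atomicMax/atomicMin only accept
                // integer types, so the helpers above reinterpret the float's
                // bits; IEEE-754 floats order like sign-magnitude integers,
                // hence the signbit() branch for negative values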
            
                extern "C" __global__ void __launch_bounds__(512) zmax_out(
                    const int n,
                    const {{type}}* __restrict__ tenIn,  // Z input only, B 1 H W
                    const {{type}}* __restrict__ tenFlow,
                    {{type}}* __restrict__ tenOut  // Z max output
                ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
                    const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) / SIZE_1(tenOut) ) % SIZE_0(tenOut);
                    const int intC = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut)                  ) % SIZE_1(tenOut);
                    const int intY = ( intIndex / SIZE_3(tenOut)                                   ) % SIZE_2(tenOut);
                    const int intX = ( intIndex                                                    ) % SIZE_3(tenOut);

                    assert(SIZE_1(tenFlow) == 2);

                    {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
                    {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);

                    if (isfinite(fltX) == false) { return; }
                    if (isfinite(fltY) == false) { return; }

                    {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX);

                    int intNorthwestX = (int) (floor(fltX));
                    int intNorthwestY = (int) (floor(fltY));
                    int intNortheastX = intNorthwestX + 1;
                    int intNortheastY = intNorthwestY;
                    int intSouthwestX = intNorthwestX;
                    int intSouthwestY = intNorthwestY + 1;
                    int intSoutheastX = intNorthwestX + 1;
                    int intSoutheastY = intNorthwestY + 1;
                    
                    /*
                    for (int i = intNorthwestX - 1; i < intNorthwestX + 3; i++)
                    {
                        for (int j = intNorthwestY - 1; j < intNorthwestY + 3; j++)
                        {
                            if ((i >= 0) && (i < SIZE_3(tenOut)) && (j >= 0) && (j < SIZE_2(tenOut))) {
                                atomicMaxFloat(&tenOut[OFFSET_4(tenOut, intN, intC, j, i)], fltIn);
                            }
                        }
                    } 
                    */

                    
                    if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) {
                        atomicMaxFloat(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn);
                    }

                    if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) {
                        atomicMaxFloat(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn);
                    }

                    if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) {
                        atomicMaxFloat(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn);
                    }

                    if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) {
                        atomicMaxFloat(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn);
                    }
                    
                } }
            ''', {
                'tenIn': tenIn,
                'tenFlow': tenFlow,
                'tenOut': tenOut
            }))(
                grid=tuple([int((tenOut.nelement() + 512 - 1) / 512), 1, 1]),
                block=tuple([512, 1, 1]),
                args=[cuda_int32(tenOut.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOut.data_ptr()],
                stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
            )

        elif tenIn.is_cuda != True:
            assert (False)

        # end

        self.save_for_backward(tenIn, tenFlow)

        return tenOut

    # end

along with some modifications to the softsplat function:

...
    elif strMode.split('-')[0] == 'soft':
        tenMetricMax = softsplat_zmax_func.apply(tenMetric, tenFlow)
        tenMetric = torch.exp(tenMetric - tenMetricMax)
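        # note (added): tenMetricMax is laid out on the *target* grid (it was
        # splatted forward), while tenMetric lives on the source grid; to match
        # Eq. 3 one may need to sample Z^max back at each source pixel, e.g. by
        # backward-warping it with tenFlow, before subtracting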
        # tenMetric = torch.exp(tenMetric)
        tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1)
...

It's fine for inference, but I can't figure out how to design the backward function for softsplat_zmax_func, since it needs to return some gradient so as not to mess up training.
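
For reference, one option I'm considering (my own assumption, based on your remark that this step need not be differentiable, not necessarily what you did) is to treat Z^max as a constant and declare both inputs non-differentiable:

    # a minimal sketch, assuming Z^max may be treated as a constant during
    # backpropagation (the paper says this step need not be differentiable);
    # returning None contributes no gradient for the corresponding input
    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(self, tenOutgrad):
        return None, None

Equivalently, one could call self.mark_non_differentiable(tenOut) in forward, or compute Z^max under torch.no_grad() and detach() the result, so that autograd never queries a backward at all. I'm not sure whether this is what the paper intends, though.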

  2. I notice that CuPy's atomicMax does not support float operands, while you said that "This can be efficiently computed in parallel using an atomic max". Could you please share with us how you handled this?
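
For reference, my snippet above works around this with the usual bit-reinterpretation trick (the atomicMaxFloat helper). A pure-PyTorch alternative I sketched (my own idea, not from the paper or this repo) uses scatter_reduce_ with reduce="amax" (available since PyTorch 1.12) and avoids float atomics entirely:

# a rough sketch (my own, not from the paper or repo): compute Z^max by
# max-scattering each source pixel's metric into its four target neighbors,
# mirroring the CUDA kernel above; assumes tenIn is B x 1 x H x W and
# tenFlow is B x 2 x H x W
import torch

def zmax_pytorch(tenIn, tenFlow):
    B, C, H, W = tenIn.shape

    gridY, gridX = torch.meshgrid(
        torch.arange(H, device=tenIn.device, dtype=tenFlow.dtype),
        torch.arange(W, device=tenIn.device, dtype=tenFlow.dtype),
        indexing='ij')

    fltX = gridX + tenFlow[:, 0]  # target x of each source pixel, B x H x W
    fltY = gridY + tenFlow[:, 1]  # target y of each source pixel

    tenOut = tenIn.new_zeros(B, C, H * W)  # zeros-init, like the kernel above
    tenVal = tenIn.reshape(B, C, H * W)

    for intDy in (0, 1):
        for intDx in (0, 1):
            intX = fltX.floor().long() + intDx
            intY = fltY.floor().long() + intDy
            valid = (intX >= 0) & (intX < W) & (intY >= 0) & (intY < H) \
                & torch.isfinite(fltX) & torch.isfinite(fltY)
            idx = (intY.clamp(0, H - 1) * W + intX.clamp(0, W - 1)).reshape(B, 1, -1)
            val = torch.where(valid.reshape(B, 1, -1), tenVal, torch.full_like(tenVal, float('-inf')))
            tenOut.scatter_reduce_(2, idx.expand(-1, C, -1), val, reduce='amax', include_self=True)

    return tenOut.reshape(B, C, H, W)

It should mirror the kernel's logic, including the zero-initialization, at the cost of four scatter passes; I haven't benchmarked it against the CUDA version.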

Thanks in advance!
