diffusion_models.models.openai_unet.QKVAttention

class diffusion_models.models.openai_unet.QKVAttention(*args, **kwargs)

Bases: Module

A module that performs QKV (query-key-value) attention.

__init__(*args, **kwargs)

Methods

__init__(*args, **kwargs)

count_flops(model, _x, y)

A counter hook for the thop package that counts the operations performed by an attention layer.

forward(qkv)

Apply QKV attention.

static count_flops(model, _x, y)

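A counter hook for the thop package that counts the operations performed by an attention layer. thop calls it with the module (model), the module's inputs (_x; the leading underscore marks it as unused) and the module's output (y).

Typical usage, where model is the network being profiled and inputs and timestamps are example arguments to its forward pass:

    import thop
    from diffusion_models.models.openai_unet import QKVAttention

    macs, params = thop.profile(
        model,
        inputs=(inputs, timestamps),
        custom_ops={QKVAttention: QKVAttention.count_flops},
    )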
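For illustration, here is a sketch of what a thop-style counter for this kind of layer might compute. It assumes the layer performs two batched matrix products of equal cost (the Qᵀ·K weight matrix and the weighted sum over V), each costing N·T²·C multiply-accumulates, and that thop invokes the counter as fn(module, inputs, output) after registering a float64 total_ops buffer on the module. The helper names attention_ops and count_attention_ops are hypothetical, softmax and scaling costs are ignored, and this is not the package's own count_flops implementation.

```python
import torch
from torch import nn


def attention_ops(y: torch.Tensor) -> int:
    """Hypothetical helper: op count for an attention output y of shape [N, C, T]."""
    n, c, t = y.shape
    # Two matmuls of equal cost (assumption): Q^T K builds an [N, T, T] weight
    # matrix (N * T * T * C MACs), and weighting V yields [N, C, T] (N * C * T * T MACs).
    return 2 * n * t * t * c


def count_attention_ops(model: nn.Module, _x, y: torch.Tensor) -> None:
    """Hypothetical thop-style counter: adds the op count to model.total_ops."""
    model.total_ops += torch.tensor([attention_ops(y)], dtype=torch.float64)


# Stand-in for what thop.profile does before running the model (assumption):
# register a total_ops buffer, then call the counter with (module, inputs, output).
module = nn.Identity()
module.register_buffer("total_ops", torch.zeros(1, dtype=torch.float64))
qkv = torch.randn(2, 3 * 8, 16)  # [N, 3C, T]
out = torch.randn(2, 8, 16)      # [N, C, T]
count_attention_ops(module, (qkv,), out)
print(int(module.total_ops.item()))  # 2 * 2 * 16 * 16 * 8 = 8192
```

The count grows quadratically in T, which is why attention layers at high spatial resolution tend to dominate the profile of a U-Net.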

forward(qkv)

Apply QKV attention.

Parameters:

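qkv – an [N x (C * 3) x T] tensor holding the queries, keys and values stacked along the channel dimension, where N is the batch size, C the number of channels per query/key/value, and T the number of positions attended over.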

Returns:

an [N x C x T] tensor containing the attention output.
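The shape contract of forward can be illustrated with a minimal, single-head sketch of scaled dot-product attention in PyTorch. It assumes Q, K and V are stacked in that order along dimension 1 and that the logits are scaled by 1/√C; the function name qkv_attention is hypothetical, and this is not the library's own forward implementation.

```python
import math

import torch


def qkv_attention(qkv: torch.Tensor) -> torch.Tensor:
    """Hypothetical single-head sketch: [N, 3C, T] -> [N, C, T].

    Assumes Q, K and V are stacked in that order along dim 1.
    """
    n, width, t = qkv.shape
    assert width % 3 == 0, "channel dim must be divisible by 3"
    ch = width // 3
    q, k, v = torch.split(qkv, ch, dim=1)  # each [N, C, T]
    # Scale Q and K by C**-0.25 each, so the logits are scaled by 1/sqrt(C).
    scale = 1.0 / math.sqrt(math.sqrt(ch))
    # weight[b, t, s] = sum_c q[b, c, t] * k[b, c, s]  -> [N, T, T]
    weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)
    # Softmax over key positions s in float32, then cast back to the input dtype.
    weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
    # out[b, c, t] = sum_s weight[b, t, s] * v[b, c, s]  -> [N, C, T]
    return torch.einsum("bts,bcs->bct", weight, v)


qkv = torch.randn(2, 3 * 8, 16)
out = qkv_attention(qkv)
print(out.shape)  # torch.Size([2, 8, 16])
```

Splitting the 1/√C factor evenly between Q and K keeps intermediate values smaller, which can help with half-precision inputs; mathematically it gives the same result as dividing the logits afterwards.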