Adds argument for number of input conv layers to use in PCT Encoder.
PiperOrigin-RevId: 605613751
Scenic Authors committed Feb 14, 2024
1 parent 3b7f0a2 commit de04c5f
Showing 1 changed file with 20 additions and 18 deletions.
scenic/projects/pointcloud/models.py: 20 additions & 18 deletions
@@ -95,8 +95,10 @@ def __call__(self,
         use_bias=True)(
             inputs)
 
-    if self.attention_fn_configs is None or self.attention_fn_configs[
-        'attention_kind'] == 'regular':
+    if (
+        self.attention_fn_configs is None
+        or self.attention_fn_configs['attention_kind'] == 'regular'
+    ):
       attention = jnp.einsum('...MC,...NC->...MN', input_q, input_k)
       if mask is not None:
         mask = nn.make_attention_mask(mask, mask)
@@ -118,8 +120,9 @@ def __call__(self,
             key,
             value,
             kernel_config=self.attention_fn_configs['performer'])
-      elif self.attention_fn_configs['performer'][
-          'masking_type'] == 'fftmasked':
+      elif (
+          self.attention_fn_configs['performer']['masking_type'] == 'fftmasked'
+      ):
         toeplitz_params = self.param('toeplitz_params', zeros,
                                      (query.shape[-2], 2 * query.shape[-3] - 1))
         output = performer.masked_performer_dot_product_attention(
@@ -128,8 +131,10 @@ def __call__(self,
             value,
             toeplitz_params=toeplitz_params,
             kernel_config=self.attention_fn_configs['performer'])
-      elif self.attention_fn_configs['performer'][
-          'masking_type'] == 'sharpmasked':
+      elif (
+          self.attention_fn_configs['performer']['masking_type']
+          == 'sharpmasked'
+      ):
         toeplitz_params = self.param(
             'toeplitz_params', zeros,
             (query.shape[-2], 5 * NUM_FT_PARAMS_PER_HEAD))
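
For orientation, a minimal sketch (not part of this commit) of the attention_fn_configs dictionaries the branches above read. Only the keys visible in the diff (attention_kind, performer, masking_type) are shown; any further Performer kernel options passed through kernel_config are assumptions and omitted.

# Hedged sketch of attention_fn_configs; keys match the lookups in the diff above.
regular_config = {
    'attention_kind': 'regular',  # selects the jnp.einsum softmax path
}

performer_config = {
    'attention_kind': 'performer',
    'performer': {
        # Selects the branch: 'fftmasked' and 'sharpmasked' add learned
        # Toeplitz parameters; this sub-dict is also passed as kernel_config.
        'masking_type': 'fftmasked',
        # ...additional kernel options would go here (assumed, not shown in diff).
    },
}
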
@@ -160,6 +165,7 @@ class PointCloudTransformerEncoder(nn.Module):
   kernel_size: int | None = 1
   encoder_feature_dim: int | None = 1024
   num_attention_layers: int | None = 4
+  num_pre_conv_layers: int = 2
   num_heads: int | None = 1
   attention_fn_configs: dict[Any, Any] | None = None
   use_attention_masking: bool | None = False
@@ -176,18 +182,14 @@ def __call__(
       train: bool = False,
       debug: bool = False,
   ):
-    output = nn.Conv(
-        self.feature_dim,
-        kernel_size=(self.kernel_size,),
-        use_bias=True,
-    )(inputs)
-    output = nn.LayerNorm(reduction_axes=-2)(output, mask=mask)
-    output = nn.Conv(
-        self.feature_dim,
-        kernel_size=(self.kernel_size,),
-        use_bias=True,
-    )(output)
-    output = nn.LayerNorm(reduction_axes=-2)(output, mask=mask)
+    output = inputs
+    for _ in range(self.num_pre_conv_layers):
+      output = nn.Conv(
+          self.feature_dim,
+          kernel_size=(self.kernel_size,),
+          use_bias=True,
+      )(output)
+      output = nn.LayerNorm(reduction_axes=-2)(output, mask=mask)
 
     # Self-attention blocks, input_shape= [B, N, D]
     attention_outputs = []
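
The net effect: the two hard-coded input Conv + LayerNorm blocks are replaced by a loop over the new num_pre_conv_layers field, whose default of 2 preserves the previous behaviour. Below is a minimal usage sketch, assuming the constructor fields visible in the diff plus a feature_dim field referenced in __call__, and an input point cloud of shape [batch, num_points, 3]; the exact __call__ signature is an assumption.

import jax
import jax.numpy as jnp
from scenic.projects.pointcloud.models import PointCloudTransformerEncoder

# Three input conv layers instead of the previous hard-coded two.
encoder = PointCloudTransformerEncoder(
    feature_dim=128,            # assumed value; used by nn.Conv in the pre-conv loop
    kernel_size=1,
    encoder_feature_dim=1024,
    num_attention_layers=4,
    num_pre_conv_layers=3,      # new argument introduced by this commit
)

points = jnp.ones((2, 1024, 3))  # [batch, num_points, xyz] (assumed shape)
variables = encoder.init(jax.random.PRNGKey(0), points)
features = encoder.apply(variables, points)
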
