| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import math |
| | from huggingface_hub import PyTorchModelHubMixin |
| |
|
| |
|
class ResBlock1D(nn.Module):
    """
    Residual block for extracting rhythmic features from audio spectrograms.

    Two dilated 1-D convolutions, each followed by BatchNorm (GELU after the
    first), plus an identity skip connection and a final GELU. The time length
    is preserved, so blocks with increasing dilation can be stacked to grow the
    receptive field without downsampling.

    Args:
        channels: Number of input and output channels.
        kernel_size: Convolution kernel size.
        dilation: Convolution dilation.

    Raises:
        ValueError: If ``(kernel_size - 1) * dilation`` is odd. Symmetric zero
            padding cannot preserve the time length in that case, and the
            residual addition in ``forward`` would fail with a shape mismatch.
    """

    def __init__(self, channels, kernel_size=3, dilation=1):
        super().__init__()
        # Total number of frames a dilated kernel spans beyond its centre tap.
        dilated_span = (kernel_size - 1) * dilation
        if dilated_span % 2 != 0:
            raise ValueError(
                "ResBlock1D needs (kernel_size - 1) * dilation to be even so the "
                f"time length is preserved; got kernel_size={kernel_size}, "
                f"dilation={dilation}"
            )
        # Half the span on each side gives output length == input length.
        padding = dilated_span // 2
        self.conv1 = nn.Conv1d(
            channels, channels, kernel_size, padding=padding, dilation=dilation
        )
        self.bn1 = nn.BatchNorm1d(channels)
        self.conv2 = nn.Conv1d(
            channels, channels, kernel_size, padding=padding, dilation=dilation
        )
        self.bn2 = nn.BatchNorm1d(channels)

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, channels, time); the output has the same shape."""
        res = x
        x = F.gelu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        return F.gelu(x + res)
| |
|
| |
|
class GameChartEvaluator(nn.Module, PyTorchModelHubMixin):
    """
    Score how well a chart matches a piece of music.

    Music and chart spectrograms are concatenated along the channel axis and
    projected to ``d_model`` channels. They are then encoded by a stack of
    dilated residual blocks and mapped to a per-timestep quality score via a
    sigmoid. The clip-level score blends the mean local score with the mean of
    the worst ``worst_fraction`` of timesteps. The mixing weight is learnable:
    ``sigmoid(raw_severity)``.

    Args:
        input_dim: Channels (mel bins) in each of the two input spectrograms.
        d_model: Hidden channel width of the encoder.
        n_layers: Number of residual blocks. Block ``i`` uses dilation
            ``2**i``, so the default of 4 gives dilations 1, 2, 4, 8.
        worst_fraction: Fraction of timesteps (at least one) averaged to form
            the worst-case score. Must be in (0, 1].
    """

    def __init__(self, input_dim=80, d_model=128, n_layers=4, worst_fraction=0.1):
        super().__init__()
        if not 0.0 < worst_fraction <= 1.0:
            raise ValueError(
                f"worst_fraction must be in (0, 1], got {worst_fraction}"
            )
        self.worst_fraction = worst_fraction

        # Music and chart spectrograms are stacked on the channel axis, hence * 2.
        self.input_proj = nn.Conv1d(
            input_dim * 2, d_model, kernel_size=3, stride=1, padding=1
        )

        # Exponentially growing dilation widens the receptive field while
        # ResBlock1D keeps the time length unchanged.
        self.encoder = nn.Sequential(
            *(
                ResBlock1D(d_model, kernel_size=3, dilation=2**i)
                for i in range(n_layers)
            )
        )

        # Maps each timestep's feature vector to a single quality logit.
        self.quality_proj = nn.Linear(d_model, 1)

        # Unconstrained; sigmoid(raw_severity) is the weight on the worst-case
        # score (0.5 at initialization).
        self.raw_severity = nn.Parameter(torch.tensor(0.0))

    def _local_scores(self, music_mels, chart_mels):
        """Return per-timestep quality scores of shape (batch, time, 1)."""
        x = torch.cat([music_mels, chart_mels], dim=1)
        x = F.gelu(self.input_proj(x))
        x = self.encoder(x)
        # nn.Linear acts on the last axis: (B, C, T) -> (B, T, C).
        x = x.permute(0, 2, 1)
        return torch.sigmoid(self.quality_proj(x))

    def forward(self, music_mels, chart_mels):
        """
        Compute one quality score per example.

        Args:
            music_mels: (batch, input_dim, time) music spectrogram.
            chart_mels: (batch, input_dim, time) chart spectrogram. It must
                share batch size and time length with ``music_mels``.

        Returns:
            Tensor of shape (batch,) with scores in [0, 1].
        """
        local_scores = self._local_scores(music_mels, chart_mels)  # (B, T, 1)

        avg_score = local_scores.mean(dim=1)

        # Mean of the k lowest-scoring timesteps (always at least one).
        k = max(1, int(local_scores.size(1) * self.worst_fraction))
        min_vals, _ = torch.topk(local_scores, k, dim=1, largest=False)
        worst_score = min_vals.mean(dim=1)

        alpha = torch.sigmoid(self.raw_severity)
        final_score = (alpha * worst_score) + ((1 - alpha) * avg_score)

        return final_score.squeeze(1)

    def predict_trace(self, music_mels, chart_mels):
        """
        Explainability method: return the per-timestep quality curve.

        Runs without gradient tracking and with every submodule temporarily in
        eval mode. BatchNorm therefore uses its running statistics and does not
        update them. Each submodule's previous train/eval mode is restored
        afterwards.

        Returns:
            local_scores: (batch, time) quality score at every timestep.
        """
        saved_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                return self._local_scores(music_mels, chart_mels).squeeze(2)
        finally:
            for module, mode in saved_modes:
                module.training = mode
| |
|
| |
|
if __name__ == "__main__":
    # Smoke test on random inputs: 2 examples, 80 mel bins, 1000 frames.
    model = GameChartEvaluator()
    model.eval()  # inference demo: BatchNorm uses running statistics
    print(
        f"Model initialized. Learnable Severity: {torch.sigmoid(model.raw_severity).item():.2f}"
    )

    m = torch.randn(2, 80, 1000)
    c = torch.randn(2, 80, 1000)

    with torch.no_grad():
        output = model(m, c)
    print(f"Output shape: {output.shape}")
    print(f"Scores: {output}")

    trace = model.predict_trace(m, c)
    print(
        f"Trace shape: {trace.shape}"
    )

    # torchinfo is only needed for the optional summary; its absence should
    # not prevent the rest of the smoke test from running.
    try:
        from torchinfo import summary
    except ImportError:
        print("torchinfo not installed; skipping model summary.")
    else:
        summary(model, input_data=[m, c])
| |
|