{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4",
"authorship_tag": "ABX9TyPNQJbT8chmsjYoacF3B3Ep",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"source": [
"# noise2music-inspired automatic music captioning\n",
"\n",
"In [noise2music](https://google-research.github.io/noise2music/), the training dataset is created by pseudo-labeling a vast collection of unlabeled music audio using two advanced deep learning models. A large language model generates a diverse set of general music-related descriptive sentences to serve as potential captions. These captions are then matched to individual music clips through zero-shot classification, leveraging a pre-trained joint embedding model designed for music and text.\n",
"\n",
"So being curious, let's try the following:\n",
"\n",
"1. Generate a lot of music descriptions with a Meta Llama 3.2 LLM.\n",
"1. Embed the generated music descriptions with a LAION CLAP text encoder.\n",
"1. Index the text embeddings for nearest neighbor retrieval with FAISS.\n",
"1. Use the corresponding audio encoder to embed an audio example.\n",
"1. Use the audio embedding as search query for retrieving text embeddings.\n",
"\n",
"**Could this simple method produce reasonable audio captions?**"
],
"metadata": {
"id": "EVsGy9kxKMn-"
}
},
{
"cell_type": "code",
"source": [
"pip install -q datasets faiss-cpu"
],
"metadata": {
"id": "vhPejzFwbCo1"
},
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import torch\n",
"import faiss\n",
"import transformers\n",
"import datasets\n",
"import polars as pl\n",
"import librosa as lr\n",
"import numpy as np\n",
"import tqdm.auto as tqdm\n",
"import seaborn as sns\n",
"\n",
"# Configure plotting.\n",
"pl.Config.set_fmt_str_lengths(256)\n",
"sns.set_style(\"ticks\")\n",
"sns.set_theme(\"notebook\")\n",
"\n",
"# Download some example audio files.\n",
"dataset = datasets.load_dataset(\"marsyas/gtzan\", trust_remote_code=True)\n",
"\n",
"# Download a pretrained text generation model.\n",
"text_generator = transformers.pipeline(\n",
" task=\"text-generation\",\n",
" model=\"meta-llama/Llama-3.2-1B-Instruct\",\n",
" torch_dtype=torch.bfloat16,\n",
" device_map=\"auto\",\n",
")\n",
"\n",
"# Download a pretrained CLAP model.\n",
"clap_model = transformers.ClapModel.from_pretrained(\"laion/larger_clap_general\")\n",
"clap_processor = transformers.ClapProcessor.from_pretrained(\"laion/larger_clap_general\")"
],
"metadata": {
"id": "Ktlv3YEs9agI"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 147
},
"id": "yrF4CdDk8irv",
"outputId": "a634a9f2-66cf-49de-fdf7-0652998e5f00"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Setting `pad_token_id` to `eos_token_id`:None for open-end generation.\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"shape: (1, 1)\n",
"┌──────────────────────────────────────────────────────────────────────────────────────────────────┐\n",
"│ generated_text │\n",
"│ --- │\n",
"│ str │\n",
"╞══════════════════════════════════════════════════════════════════════════════════════════════════╡\n",
"│ The piece features a haunting piano melody, punctuated by sparse strings and a subtle, pulsing │\n",
"│ bass line that creates an eerie, atmospheric backdrop for a whispered vocal │\n",
"└──────────────────────────────────────────────────────────────────────────────────────────────────┘"
]
},
"metadata": {},
"execution_count": 3
}
],
"source": [
"# Generate a lot of music descriptions.\n",
"messages = [\n",
" {\"role\": \"system\", \"content\": \"You are a music reviewer who is specific, brief and accurate.\"},\n",
" {\"role\": \"user\", \"content\": \"Imagine any random piece of music and describe how it sounds in one sentence without mentioning the name or artist.\"},\n",
"]\n",
"descriptions = text_generator(\n",
" messages,\n",
" num_return_sequences=1000,\n",
" return_full_text=False,\n",
" do_sample=True,\n",
" num_beams=1,\n",
" max_new_tokens=32,\n",
")\n",
"descriptions = pl.DataFrame(descriptions)\n",
"\n",
"# Save the descriptions to file.\n",
"descriptions.write_parquet(\"music_descriptions.parquet\")\n",
"descriptions.sample()"
]
},
{
"cell_type": "code",
"source": [
"num_dimensions = clap_model.config.projection_dim\n",
"index = faiss.IndexFlatL2(num_dimensions)"
],
"metadata": {
"id": "WFJIhAWJ0unl"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Tokenize text descriptions.\n",
"inputs = clap_processor(text=descriptions[\"generated_text\"].to_list(), return_tensors=\"pt\", padding=True)\n",
"\n",
"# Populate local vector database.\n",
"batch_size = 8\n",
"for i in tqdm.trange(0, len(inputs[\"input_ids\"]), batch_size, desc=\"Indexing descriptions\"):\n",
" input_ids = inputs[\"input_ids\"][i:i + batch_size]\n",
" attention_mask = inputs[\"attention_mask\"][i:i + batch_size]\n",
"\n",
" # Embed the tokens.\n",
" text_embeddings = clap_model.get_text_features(input_ids, attention_mask)\n",
"\n",
" # Add embeddings to the index.\n",
" index.add(text_embeddings.numpy(force=True))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 49,
"referenced_widgets": [
"8169284539774f2ca1d46492faee3b4a",
"a8353dd195154f95b98066a0bcb8bb2b",
"b68893e982f8455e8d4c4bcf4062bfde",
"8e6add65bea04b3fbae40b7febb47a77",
"9d007a2cd86b45709fe2273211bd89b7",
"57454396c13042cea979f0b6b8003dda",
"6387af3665084af08b3c3e09662a0ff4",
"af1b1b90c1024a47abed21819ac2fe41",
"9429673b334e4fc7a43774b9d0dfc9c4",
"6658519e4ceb45adb7991cf5f2b2897b",
"56923a5e99b44bc4834a750a97651e6b"
]
},
"id": "-vzUKMwoIAyq",
"outputId": "f8f6c1f1-fe9c-4068-8d5c-ec21ba91a461"
},
"execution_count": 5,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"Indexing descriptions: 0%| | 0/125 [00:00, ?it/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "8169284539774f2ca1d46492faee3b4a"
}
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"# Load an example audio file.\n",
"audio_file = lr.example(\"trumpet\")\n",
"waveform, samplerate = lr.load(audio_file, sr=clap_processor.feature_extractor.sampling_rate)\n",
"\n",
"# Compute audio embedding.\n",
"inputs = clap_processor(audios=waveform, return_tensors=\"pt\", sampling_rate=clap_processor.feature_extractor.sampling_rate)\n",
"audio_embedding = clap_model.get_audio_features(**inputs)\n",
"audio_embedding.shape"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "MHxCjhxEHKmd",
"outputId": "d6bc8d0a-8873-4fb7-da7d-decaa67fa096"
},
"execution_count": 6,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"torch.Size([1, 512])"
]
},
"metadata": {},
"execution_count": 6
}
]
},
{
"cell_type": "code",
"source": [
"# Search for similar embeddings.\n",
"max_results = 1000\n",
"similarities, neighbor_ids = index.search(audio_embedding.numpy(force=True), k=max_results)\n",
"similarities.shape"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "2sjJQ04ROh8i",
"outputId": "4c3507bf-4875-4e87-9bd7-232e9ddd703a"
},
"execution_count": 7,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(1, 1000)"
]
},
"metadata": {},
"execution_count": 7
}
]
},
{
"cell_type": "code",
"source": [
"# Lookup underlying text descriptions.\n",
"matches = descriptions[neighbor_ids[0][:max_results]].with_columns(pl.Series(\"scores\", similarities[0][:max_results]))\n",
"matches.top_k(5, by=\"scores\")"
],
"metadata": {
"id": "0G7HvVXXOv3g",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 255
},
"outputId": "bf5b9424-d5d6-4d01-9a17-bceca3904566"
},
"execution_count": 8,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"shape: (5, 2)\n",
"┌───────────────────────────────────────────────────────────────────────────────────────┬──────────┐\n",
"│ generated_text ┆ scores │\n",
"│ --- ┆ --- │\n",
"│ str ┆ f32 │\n",
"╞═══════════════════════════════════════════════════════════════════════════════════════╪══════════╡\n",
"│ This piece of music features a haunting, atmospheric arrangement of eerie whispers ┆ 2.301279 │\n",
"│ and dissonant harmonies, punctuated by sudden, percussive bursts of sound that ┆ │\n",
"│ The piece is a haunting, atmospheric soundscape of pulsing synthesizers, eerie ┆ 2.252909 │\n",
"│ whispers, and a steady, pulsing heartbeat, evoking a sense of fore ┆ │\n",
"│ This piece features a haunting, atmospheric soundscape of whispers and creaks, ┆ 2.249522 │\n",
"│ punctuated by a sparse, pulsing rhythm that gradually builds into a crescendo ┆ │\n",
"│ The piece features a mesmerizing blend of eerie whispers, pulsating electronic beats, ┆ 2.237405 │\n",
"│ and haunting vocal harmonies that create an unsettling atmosphere, gradually building ┆ │\n",
"│ towards a dis ┆ │\n",
"│ This 5-minute composition features a gradual build-up of atmospheric textures, with ┆ 2.235468 │\n",
"│ layers of haunting piano and whispery vocals gradually giving way to a driving, ┆ │\n",
"│ pulsing ┆ │\n",
"└───────────────────────────────────────────────────────────────────────────────────────┴──────────┘"
],
"text/html": [
"
\n",
"shape: (5, 2)
generated_text
scores
str
f32
"This piece of music features a haunting, atmospheric arrangement of eerie whispers and dissonant harmonies, punctuated by sudden, percussive bursts of sound that"
2.301279
"The piece is a haunting, atmospheric soundscape of pulsing synthesizers, eerie whispers, and a steady, pulsing heartbeat, evoking a sense of fore"
2.252909
"This piece features a haunting, atmospheric soundscape of whispers and creaks, punctuated by a sparse, pulsing rhythm that gradually builds into a crescendo"
2.249522
"The piece features a mesmerizing blend of eerie whispers, pulsating electronic beats, and haunting vocal harmonies that create an unsettling atmosphere, gradually building towards a dis"
2.237405
"This 5-minute composition features a gradual build-up of atmospheric textures, with layers of haunting piano and whispery vocals gradually giving way to a driving, pulsing"