This commit is contained in:
2025-03-07 21:00:42 -05:00
parent bb17c0651e
commit 65380da39a
32 changed files with 7070 additions and 397 deletions

View File

@@ -40,6 +40,10 @@ android {
}
dependencies {
api(fileTree("libs") {
include("*.jar")
})
api(files("libs/opus.aar"))
implementation(libs.androidx.core.ktx)
implementation(libs.androidx.appcompat)
implementation(libs.play.services.wearable)

BIN
mobile/libs/opus.aar Normal file

Binary file not shown.

Binary file not shown.

File diff suppressed because it is too large Load Diff

Binary file not shown.

Binary file not shown.

View File

@@ -1,201 +0,0 @@
package com.birdsounds.identify;
import android.app.Activity;
import android.content.Context;
import android.util.Log;
import android.view.View;
import android.widget.Toast;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.math.BigInteger;
import java.net.URL;
import java.net.URLConnection;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
@SuppressWarnings("ResultOfMethodCallIgnored")
public class Downloader {
static final String modelFILE = "modelfx.tflite";
static final String metaModelFILE = "metaModelfx.tflite";
static final String modelURL = "https://raw.githubusercontent.com/woheller69/whoBIRD-TFlite/master/BirdNET_GLOBAL_6K_V2.4_Model_FP16.tflite";
static final String model32URL = "https://raw.githubusercontent.com/woheller69/whoBIRD-TFlite/master/BirdNET_GLOBAL_6K_V2.4_Model_FP32.tflite";
static final String metaModelURL = "https://raw.githubusercontent.com/woheller69/whoBIRD-TFlite/master/BirdNET_GLOBAL_6K_V2.4_MData_Model_FP16.tflite";
static final String modelMD5 = "b1c981fe261910b473b9b7eec9ebcd4e";
static final String model32MD5 = "6c7c42106e56550fc8563adb31bc120e";
static final String metaModelMD5 ="f1a078ae0f244a1ff5a8f1ccb645c805";
public static boolean checkModels(final Activity activity) {
File modelFile = new File(activity.getDir("filesdir", Context.MODE_PRIVATE) + "/" + modelFILE);
File metaModelFile = new File(activity.getDir("filesdir", Context.MODE_PRIVATE) + "/" + metaModelFILE);
String calcModelMD5 = "";
String calcMetaModelMD5 = "";
if (modelFile.exists()) {
try {
byte[] data = Files.readAllBytes(Paths.get(modelFile.getPath()));
byte[] hash = MessageDigest.getInstance("MD5").digest(data);
calcModelMD5 = new BigInteger(1, hash).toString(16);
} catch (IOException | NoSuchAlgorithmException e) {
throw new RuntimeException(e);
}
}
if (metaModelFile.exists()) {
try {
byte[] data = Files.readAllBytes(Paths.get(metaModelFile.getPath()));
byte[] hash = MessageDigest.getInstance("MD5").digest(data);
calcMetaModelMD5 = new BigInteger(1, hash).toString(16);
} catch (IOException | NoSuchAlgorithmException e) {
throw new RuntimeException(e);
}
}
if (modelFile.exists() && !(calcModelMD5.equals(modelMD5) || calcModelMD5.equals(model32MD5))) modelFile.delete();
if (metaModelFile.exists() && !calcMetaModelMD5.equals(metaModelMD5)) metaModelFile.delete();
return (calcModelMD5.equals(modelMD5) || calcModelMD5.equals(model32MD5)) && calcMetaModelMD5.equals(metaModelMD5);
}
public static void downloadModels(final Activity activity) {
File modelFile = new File(activity.getDir("filesdir", Context.MODE_PRIVATE) + "/" + modelFILE);
Log.d("Heyy","Model file checking");
if (!modelFile.exists()) {
Log.d("whoBIRD", "model file does not exist");
Thread thread = new Thread(() -> {
try {
URL url;
if (false) url = new URL(model32URL);
else url = new URL(modelURL);
Log.d("whoBIRD", "Download model");
URLConnection ucon = url.openConnection();
Log.d("whoBIRD", "i am here");
ucon.setReadTimeout(5000);
ucon.setConnectTimeout(10000);
InputStream is = ucon.getInputStream();
BufferedInputStream inStream = new BufferedInputStream(is, 1024 * 5);
modelFile.createNewFile();
FileOutputStream outStream = new FileOutputStream(modelFile);
byte[] buff = new byte[5 * 1024];
int len;
while ((len = inStream.read(buff)) != -1) {
outStream.write(buff, 0, len);
}
outStream.flush();
outStream.close();
inStream.close();
String calcModelMD5="";
if (modelFile.exists()) {
byte[] data = Files.readAllBytes(Paths.get(modelFile.getPath()));
byte[] hash = MessageDigest.getInstance("MD5").digest(data);
calcModelMD5 = new BigInteger(1, hash).toString(16);
} else {
throw new IOException(); //throw exception if there is no modelFile at this point
}
if (!(calcModelMD5.equals(modelMD5) || calcModelMD5.equals(model32MD5) )){
modelFile.delete();
activity.runOnUiThread(() -> {
Toast.makeText(activity, activity.getResources().getString(R.string.error_download), Toast.LENGTH_SHORT).show();
});
} else {
activity.runOnUiThread(() -> {
});
}
} catch (NoSuchAlgorithmException | IOException i) {
activity.runOnUiThread(() -> Toast.makeText(activity, activity.getResources().getString(R.string.error_download), Toast.LENGTH_SHORT).show());
modelFile.delete();
}
});
thread.start();
try {
thread.join();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
} else {
Log.d("whoBIRD","model exists");
activity.runOnUiThread(() -> {
});
}
File metaModelFile = new File(activity.getDir("filesdir", Context.MODE_PRIVATE) + "/" + metaModelFILE);
if (!metaModelFile.exists()) {
Log.d("whoBIRD", "meta model file does not exist");
Thread thread = new Thread(() -> {
try {
URL url = new URL(metaModelURL);
Log.d("whoBIRD", "Download meta model");
URLConnection ucon = url.openConnection();
ucon.setReadTimeout(5000);
ucon.setConnectTimeout(10000);
InputStream is = ucon.getInputStream();
BufferedInputStream inStream = new BufferedInputStream(is, 1024 * 5);
metaModelFile.createNewFile();
FileOutputStream outStream = new FileOutputStream(metaModelFile);
byte[] buff = new byte[5 * 1024];
int len;
while ((len = inStream.read(buff)) != -1) {
outStream.write(buff, 0, len);
}
outStream.flush();
outStream.close();
inStream.close();
String calcMetaModelMD5="";
if (metaModelFile.exists()) {
byte[] data = Files.readAllBytes(Paths.get(metaModelFile.getPath()));
byte[] hash = MessageDigest.getInstance("MD5").digest(data);
calcMetaModelMD5 = new BigInteger(1, hash).toString(16);
} else {
throw new IOException(); //throw exception if there is no modelFile at this point
}
if (!calcMetaModelMD5.equals(metaModelMD5)){
metaModelFile.delete();
activity.runOnUiThread(() -> {
Toast.makeText(activity, activity.getResources().getString(R.string.error_download), Toast.LENGTH_SHORT).show();
});
} else {
activity.runOnUiThread(() -> {
});
}
} catch (NoSuchAlgorithmException | IOException i) {
activity.runOnUiThread(() -> Toast.makeText(activity, activity.getResources().getString(R.string.error_download), Toast.LENGTH_SHORT).show());
metaModelFile.delete();
Log.w("whoBIRD", activity.getResources().getString(R.string.error_download), i);
}
});
thread.start();
try {
thread.join();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
} else {
Log.d("whoBIRD", "meta file exists");
activity.runOnUiThread(() -> {
});
}
}
}

View File

@@ -1,59 +0,0 @@
package com.birdsounds.identify;
import android.Manifest;
import android.content.Context;
import android.content.pm.PackageManager;
import android.location.LocationListener;
import android.location.LocationManager;
import android.os.Bundle;
import android.widget.Toast;
import androidx.core.app.ActivityCompat;
public class Location {
private static LocationListener locationListenerGPS;
static void stopLocation(Context context){
LocationManager locationManager = (LocationManager) context.getSystemService(Context.LOCATION_SERVICE);
if (locationListenerGPS!=null) locationManager.removeUpdates(locationListenerGPS);
locationListenerGPS=null;
}
static void requestLocation(Context context, SoundClassifier soundClassifier) {
if (ActivityCompat.checkSelfPermission(context, Manifest.permission.ACCESS_COARSE_LOCATION) == PackageManager.PERMISSION_GRANTED && checkLocationProvider(context)) {
LocationManager locationManager = (LocationManager) context.getSystemService(Context.LOCATION_SERVICE);
if (locationListenerGPS==null) locationListenerGPS = new LocationListener() {
@Override
public void onLocationChanged(android.location.Location location) {
soundClassifier.runMetaInterpreter(location);
}
@Deprecated
@Override
public void onStatusChanged(String provider, int status, Bundle extras) {
}
@Override
public void onProviderEnabled(String provider) {
}
@Override
public void onProviderDisabled(String provider) {
}
};
locationManager.requestLocationUpdates(LocationManager.GPS_PROVIDER, 60000, 0, locationListenerGPS);
}
}
public static boolean checkLocationProvider(Context context) {
LocationManager locationManager = (LocationManager) context.getSystemService(Context.LOCATION_SERVICE);
if (!locationManager.isProviderEnabled(LocationManager.GPS_PROVIDER)){
Toast.makeText(context, "Error no GPS", Toast.LENGTH_SHORT).show();
return false;
} else {
return true;
}
}
}

View File

@@ -1,57 +0,0 @@
package com.birdsounds.identify
import android.content.Intent
import android.util.Half.abs
import android.util.Log
import androidx.localbroadcastmanager.content.LocalBroadcastManager
import com.google.android.gms.wearable.MessageEvent
import com.google.android.gms.wearable.WearableListenerService
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.ShortBuffer
class MessageListenerService : WearableListenerService() {
private val tag = "MessageListenerService"
// fun placeSoundClassifier(soundClassifier: SoundClassifier)
override fun onMessageReceived(p0: MessageEvent) {
super.onMessageReceived(p0)
// MainActivity
val soundclassifier = MainActivity.soundClassifier
if (soundclassifier == null) {
Log.w(tag, "Have invalid sound classifier")
return
} else {
Log.w(tag, "Have valid classifier")
}
val short_array = ShortArray(48000 * 3)
var tstamp_bytes = p0.data.copyOfRange(0, Long.SIZE_BYTES)
var audio_bytes = p0.data.copyOfRange(Long.SIZE_BYTES, p0.data.size)
var string_send: String = ""
ByteBuffer.wrap(audio_bytes).order(
ByteOrder.LITTLE_ENDIAN
).asShortBuffer().get(short_array)
Log.w(tag, short_array.sum().toString())
var sorted_list = soundclassifier.executeScoring(short_array)
Log.w(tag, "")
for (i in 0 until 5) {
val score = sorted_list[i].value
val index = sorted_list[i].index
val species_name = soundclassifier.labelList[index]
Log.w(tag, species_name + ", " + score.toString())
string_send+= species_name
string_send+=','
string_send+=score.toString()
string_send+=';'
}
MessageSenderFromPhone.sendMessage("/audio", tstamp_bytes + string_send.toByteArray(), this)
// Log.i(tag , short_array.map( { abs(it)}).sum().toString())
// Log.i(tag, short_array[0].toString())
// Log.i(tag, p0.data.toString(Charsets.US_ASCII))
// broadcastMessage(p0)
}
}

View File

@@ -0,0 +1,101 @@
import android.content.ContentValues.TAG
import android.media.MediaCodec
import android.media.MediaCodecInfo
import android.media.MediaCodecList
import android.media.MediaFormat
import android.util.Log
import java.nio.ByteBuffer
fun listMediaCodecDecoders() {
val codecList = MediaCodecList(MediaCodecList.ALL_CODECS) // Get all codecs
val codecs = codecList.codecInfos
Log.e(TAG, "Available MediaCodec Decoders:")
for (codec in codecs) {
if (!codec.isEncoder) { // Check if the codec is a decoder
Log.e(TAG, "Decoder: ${codec.name}")
// List the MIME types supported by the decoder
val supportedTypes = codec.supportedTypes
Log.e(TAG, " Supported Types:")
for (type in supportedTypes) {
Log.e(TAG, " $type")
}
}
}
}
fun decodeAACToPCM(inputData: ByteArray): ShortArray {
// listMediaCodecDecoders();
// Media format configuration for AAC
val mediaFormat = MediaFormat.createAudioFormat(
MediaFormat.MIMETYPE_AUDIO_OPUS, // MIME type for AAC
48000, // Sample rate, change this based on your input data
1 // Channel count, change this based on your input data
)
// mediaFormat.setInteger(MediaFormat.KEY_BIT_RATE, 64000) // 128kbps
// mediaFormat.setInteger(MediaFormat.KEY_IS_ADTS, 0) // AAC should use ADTS header
// mediaFormat.setInteger(MediaFormat.KEY_AAC_PROFILE, MediaCodecInfo.CodecProfileLevel.AACObjectLC)
// Create a decoder for AAC
val mediaCodec = MediaCodec.createDecoderByType( MediaFormat.MIMETYPE_AUDIO_OPUS);
mediaCodec.configure(mediaFormat, null, null, 0)
mediaCodec.start()
val decodedSamples = mutableListOf<Short>()
// Variables for handling input and output buffers
val bufferInfo = MediaCodec.BufferInfo()
var inputOffset = 0;
while (inputOffset < inputData.size || true) {
// Feed input data to the codec
val inputBufferIndex = mediaCodec.dequeueInputBuffer(100000) // Timeout in microseconds
if (inputBufferIndex >= 0 && inputOffset < inputData.size) {
val inputBuffer: ByteBuffer? = mediaCodec.getInputBuffer(inputBufferIndex)
inputBuffer?.clear()
// Calculate the number of bytes to write to the buffer
val chunkSize = kotlin.math.min(inputBuffer?.capacity() ?: 0, inputData.size - inputOffset)
Log.e(TAG, "Chunk size: " + chunkSize.toString())
inputBuffer?.put(inputData, inputOffset, chunkSize)
inputOffset += chunkSize
// Pass the data to the codec
mediaCodec.queueInputBuffer(inputBufferIndex, 0, chunkSize, 0, 0)
}
// Process output data
val outputBufferIndex = mediaCodec.dequeueOutputBuffer(bufferInfo, 10000)
Log.e(TAG, "Output buffer index: " + outputBufferIndex.toString())
if (outputBufferIndex >= 0) {
val outputBuffer: ByteBuffer? = mediaCodec.getOutputBuffer(outputBufferIndex)
// Convert byte buffer to PCM data (16-bit integers)
val pcmData = ShortArray(bufferInfo.size / 2)
outputBuffer?.asShortBuffer()?.get(pcmData)
// Add PCM data to the final output array
decodedSamples.addAll(pcmData.toList())
// Release the output buffer
mediaCodec.releaseOutputBuffer(outputBufferIndex, false)
} else if (outputBufferIndex == MediaCodec.INFO_OUTPUT_FORMAT_CHANGED) {
// Handle format changes, if needed
mediaCodec.outputFormat
} else if (outputBufferIndex == MediaCodec.INFO_TRY_AGAIN_LATER) {
// Output buffer not available, retry later
if (inputOffset >= inputData.size) break
}
}
// Release the codec when done
mediaCodec.stop()
mediaCodec.release()
return decodedSamples.toShortArray()
}

View File

@@ -0,0 +1,52 @@
package com.birdsounds.identify
import android.content.Context
import java.io.File
import java.io.FileOutputStream
import java.io.IOException
class Downloader(mainActivity: MainActivity) {
private val settings = Settings();
private var activity: MainActivity = mainActivity;
private var context: Context = activity.applicationContext;
fun copyAssetToFolder(assetName: String, destinationPath: String): Boolean {
try {
// Get the input stream from the asset
val assetInputStream = context.assets.open(assetName)
// Create the destination directory if it doesn't exist
val destinationFile = File(activity.getDir("",Context.MODE_PRIVATE).absolutePath + "/" + destinationPath)
destinationFile.parentFile?.mkdirs()
// Copy the file
val buffer = ByteArray(1024)
val outputStream = FileOutputStream(destinationFile)
var read: Int
while (assetInputStream.read(buffer).also { read = it } != -1) {
outputStream.write(buffer, 0, read)
}
assetInputStream.close()
outputStream.flush()
outputStream.close()
return true
} catch (e: IOException) {
e.printStackTrace()
return false
}
}
fun prepareModelFiles()
{
copyAssetToFolder(settings.pkg_model_file, settings.local_model_file);
copyAssetToFolder(settings.pkg_meta_model_file, settings.local_meta_model_file);
}
}

View File

@@ -0,0 +1,71 @@
package com.birdsounds.identify
import android.Manifest
import android.content.Context
import android.content.pm.PackageManager
import android.location.Location
import android.location.LocationListener
import android.location.LocationManager
import android.os.Bundle
import android.util.Log
import android.widget.Toast
import androidx.core.app.ActivityCompat
object Location {
private var locationListenerGPS: LocationListener? = null
fun stopLocation(context: Context) {
val locationManager = context.getSystemService(Context.LOCATION_SERVICE) as LocationManager
if (locationListenerGPS != null) locationManager.removeUpdates(
locationListenerGPS!!
)
locationListenerGPS = null
}
fun requestLocation(context: Context, soundClassifier: SoundClassifier) {
if (ActivityCompat.checkSelfPermission(
context,
Manifest.permission.ACCESS_COARSE_LOCATION
) == PackageManager.PERMISSION_GRANTED && checkLocationProvider(context)
) {
val locationManager =
context.getSystemService(Context.LOCATION_SERVICE) as LocationManager
if (locationListenerGPS == null) locationListenerGPS = object : LocationListener {
override fun onLocationChanged(location: Location) {
Log.w(TAG, "Got location changed");
while (!soundClassifier.is_model_ready()) {
Thread.sleep(50);
}
Log.w(TAG, "Sound classifier is ready");
soundClassifier.runMetaInterpreter(location)
soundClassifier.runMetaInterpreter(location)
}
@Deprecated("")
override fun onStatusChanged(provider: String, status: Int, extras: Bundle) {
}
override fun onProviderEnabled(provider: String) {
}
override fun onProviderDisabled(provider: String) {
}
}
locationManager.requestLocationUpdates(
LocationManager.GPS_PROVIDER, 60000, 0f,
locationListenerGPS!!
)
}
}
fun checkLocationProvider(context: Context): Boolean {
val locationManager = context.getSystemService(Context.LOCATION_SERVICE) as LocationManager
if (!locationManager.isProviderEnabled(LocationManager.GPS_PROVIDER)) {
Toast.makeText(context, "Error no GPS", Toast.LENGTH_SHORT).show()
return false
} else {
return true
}
}
}

View File

@@ -3,7 +3,6 @@ package com.birdsounds.identify
import android.content.pm.PackageManager
import android.os.Bundle
import android.Manifest
import android.annotation.SuppressLint
import android.util.Log
import androidx.activity.enableEdgeToEdge
import androidx.appcompat.app.AppCompatActivity
@@ -13,6 +12,13 @@ import androidx.core.view.WindowInsetsCompat
import com.google.android.gms.wearable.ChannelClient
import com.google.android.gms.wearable.Wearable
val Any.TAG: String
get() {
val tag = javaClass.simpleName
return if (tag.length <= 23) tag else tag.substring(0, 23)
}
class MainActivity : AppCompatActivity() {
// private lateinit var soundClassifier: SoundClassifier
val REQUEST_PERMISSIONS = 1337
@@ -26,15 +32,17 @@ class MainActivity : AppCompatActivity() {
.registerChannelCallback(object : ChannelClient.ChannelCallback() {
override fun onChannelOpened(channel: ChannelClient.Channel) {
super.onChannelOpened(channel)
Log.d("HEY", "onChannelOpened")
Log.d(TAG, "onChannelOpened")
}
}
)
Downloader.downloadModels(this)
Downloader(this).prepareModelFiles();
Log.w(TAG, "Finished setting up downloader")
requestPermissions()
soundClassifier = SoundClassifier(this, SoundClassifier.Options())
Location.requestLocation(this, soundClassifier)
Log.w(TAG, "Starting sound classifier")
Location.requestLocation(this, soundClassifier!!)
Log.w(TAG, "Starting location requester")
ViewCompat.setOnApplyWindowInsetsListener(findViewById(R.id.main)) { v, insets ->
val systemBars = insets.getInsets(WindowInsetsCompat.Type.systemBars())
v.setPadding(systemBars.left, systemBars.top, systemBars.right, systemBars.bottom)
@@ -60,6 +68,6 @@ class MainActivity : AppCompatActivity() {
perms.add(Manifest.permission.ACCESS_COARSE_LOCATION)
perms.add(Manifest.permission.ACCESS_FINE_LOCATION)
}
if (!perms.isEmpty()) requestPermissions(perms.toTypedArray(), REQUEST_PERMISSIONS)
if (perms.isNotEmpty()) requestPermissions(perms.toTypedArray(), REQUEST_PERMISSIONS)
}
}

View File

@@ -1,5 +1,4 @@
package com.birdsounds.identify
import android.content.Intent
object MessageConstants {
const val intentName = "WearableMessageDisplay"

View File

@@ -0,0 +1,56 @@
package com.birdsounds.identify
import android.util.Log
import com.google.android.gms.wearable.MessageEvent
import com.google.android.gms.wearable.WearableListenerService
import decodeAACToPCM
class MessageListenerService : WearableListenerService() {
// fun placeSoundClassifier(soundClassifier: SoundClassifier)
override fun onMessageReceived(p0: MessageEvent) {
super.onMessageReceived(p0)
// MainActivity
Log.w(TAG, "Data recv: "+p0.data.size.toString() + " bytes")
val soundclassifier = MainActivity.soundClassifier
if (soundclassifier == null) {
Log.w(TAG, "Have invalid sound classifier")
return
} else {
Log.w(TAG, "Have valid classifier")
}
var tstamp_bytes = p0.data.copyOfRange(0, Long.SIZE_BYTES)
var audio_bytes = p0.data.copyOfRange(Long.SIZE_BYTES, p0.data.size)
var string_send: String = ""
val pcm_byte_array = decodeAACToPCM(audio_bytes)
Log.e(TAG,"Size of short array buffer: "+ pcm_byte_array.size.toString());
// ByteBuffer.wrap(audio_bytes).order(
// ByteOrder.LITTLE_ENDIAN
// ).asShortBuffer().get(short_array)
Log.e(TAG, pcm_byte_array.sum().toString())
Log.e(TAG, "STARTING SCORING");
// var sorted_list = soundclassifier.executeScoring(short_array)
// Log.w(TAG, "FINISHED SCORING");
// Log.w(TAG, "")
// for (i in 0 until 5) {
// val score = sorted_list[i].value
// val index = sorted_list[i].index
// val species_name = soundclassifier.labelList[index]
// Log.w(TAG, species_name + ", " + score.toString())
// string_send+= species_name
// string_send+=','
// string_send+=score.toString()
// string_send+=';'
// }
MessageSenderFromPhone.sendMessage("/audio", tstamp_bytes + string_send.toByteArray(), this)
}
}

View File

@@ -0,0 +1,9 @@
package com.birdsounds.identify
class Settings {
var local_model_file: String = "2024_08_16_audio_model.tflite"
var pkg_model_file: String = "2024_08_16/audio-model.tflite"
var local_meta_model_file: String = "2024_08_16_meta_model.tflite"
var pkg_meta_model_file: String = "2024_08_16/meta-model.tflite"
}

View File

@@ -1,11 +1,11 @@
package com.birdsounds.identify
import android.content.Context
import android.location.Location
import android.os.SystemClock
import android.preference.PreferenceManager
import android.util.Log
import androidx.annotation.Nullable
import org.tensorflow.lite.Interpreter
import java.io.BufferedReader
import java.io.File
@@ -23,8 +23,6 @@ import kotlin.concurrent.scheduleAtFixedRate
import kotlin.math.ceil
import kotlin.math.cos
import uk.me.berndporr.iirj.Butterworth
import java.nio.ShortBuffer
import kotlin.math.round
import kotlin.math.sin
@@ -32,7 +30,11 @@ class SoundClassifier(
context: Context,
private val options: Options = Options()
) {
internal var mContext: Context
val TAG = "Sound Classifier"
init {
@@ -40,14 +42,10 @@ class SoundClassifier(
}
class Options(
/** Path of the converted model label file, relative to the assets/ directory. */
val labelsBase: String = "labels",
/** Path of the converted .tflite file, relative to the assets/ directory. */
val assetFile: String = "assets.txt",
/** Path of the converted .tflite file, relative to the assets/ directory. */
val modelPath: String = "modelfx.tflite",
/** Path of the meta model .tflite file, relative to the assets/ directory. */
val metaModelPath: String = "metaModelfx.tflite",
/** The required audio sample rate in Hz. */
val sampleRate: Int = 48000,
/** Multiplier for audio samples */
@@ -83,7 +81,7 @@ class SoundClassifier(
/** Number of output classes of the TFLite model. */
private var modelNumClasses = 0
private var metaModelNumClasses = 0
private var settings: Settings = Settings();
/** Used to hold the real-time probabilities predicted by the model for the output classes. */
private lateinit var predictionProbs: FloatArray
@@ -94,19 +92,26 @@ class SoundClassifier(
private var recognitionTask: TimerTask? = null
/** Buffer that holds audio PCM sample that are fed to the TFLite model for inference. */
private lateinit var inputBuffer: FloatBuffer
private lateinit var metaInputBuffer: FloatBuffer
init {
private var model_ready = false;
init {;
setupDecoder(context)
loadLabels(context)
loadAssetList(context)
setupInterpreter(context)
setupMetaInterpreter(context)
warmUpModel()
this.model_ready = true;
}
fun is_model_ready(): Boolean
{
return this.model_ready;
}
private fun setupDecoder(context: Context) {
}
/** Retrieve asset list from "asset_list" file */
private fun loadAssetList(context: Context) {
@@ -168,10 +173,12 @@ class SoundClassifier(
private fun setupInterpreter(context: Context) {
try {
val modelFilePath = context.getDir(
"filesdir",
val modelFilePath =
context.getDir(
"",
Context.MODE_PRIVATE
).absolutePath + "/" + options.modelPath
).absolutePath + "/" + settings.local_model_file;
Log.i(TAG, "Trying to create TFLite buffer from $modelFilePath")
val modelFile = File(modelFilePath)
val tfliteBuffer: ByteBuffer =
@@ -211,9 +218,9 @@ class SoundClassifier(
try {
val metaModelFilePath = context.getDir(
"filesdir",
"",
Context.MODE_PRIVATE
).absolutePath + "/" + options.metaModelPath
).absolutePath + "/" + settings.local_meta_model_file
Log.i(TAG, "Trying to create TFLite buffer from $metaModelFilePath")
val metaModelFile = File(metaModelFilePath)
val tfliteBuffer: ByteBuffer =
@@ -244,6 +251,7 @@ class SoundClassifier(
}
// Fill the array with 1 initially.
metaPredictionProbs = FloatArray(metaModelNumClasses) { 1f }
metaInputBuffer = FloatBuffer.allocate(metaModelInputLength)
}
@@ -333,7 +341,6 @@ class SoundClassifier(
else inputBuffer.put(i, butterworth.filter(s.toDouble()).toFloat())
}
inputBuffer.rewind()
outputBuffer.rewind()
interpreter.run(inputBuffer, outputBuffer)