Files
geo_birds_list/run_me.py
2026-01-05 20:28:51 -05:00

247 lines
7.3 KiB
Python

import tensorflow as tf
import numpy as np

# Load the Merlin geo model and get it ready for single-sample inference.
interpreter = tf.lite.Interpreter(model_path="geo_v46.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

print(input_details)
# Prints (trimmed to not overwhelm) three scalar float32 inputs:
#   index 0: 'serving_default_longitude:0',    shape [1]
#   index 1: 'serving_default_week_of_year:0', shape [1]
#   index 2: 'serving_default_latitude:0',     shape [1]
print(output_details)
# Prints (trimmed) a single float32 output:
#   index 88, shape [1, 2728] — one score per species
# %%
# Species names, one per model output index.
with open('labels_species.txt', 'r') as ff:
    output_label = ff.read().split('\n')

# Query the model for Ann Arbor, MI.
print('\n\n')
print('Ann Arbor')
interpreter.set_tensor(0, [np.float32(-83.75)])  # longitude
interpreter.set_tensor(1, [np.float32(1.0)])     # first week of the year
interpreter.set_tensor(2, [np.float32(42.29)])   # latitude
interpreter.invoke()
output_probabilities_ish = interpreter.get_tensor(88).squeeze()
top_ten = output_probabilities_ish.argsort()[::-1][:10]
for rank, species_idx in enumerate(top_ten):
    print(rank, output_label[species_idx], output_probabilities_ish[species_idx])
# Ann Arbor
# 0 Junco hyemalis 0.32183477
# 1 Cardinalis cardinalis 0.31213352
# 2 Passer domesticus 0.28218403
# 3 Dryobates pubescens 0.27699217
# 4 Poecile atricapillus 0.2743776
# 5 Corvus brachyrhynchos 0.27222255
# 6 Zenaida macroura 0.26566097
# 7 Branta canadensis 0.2610584
# 8 Melanerpes carolinus 0.25785306
# 9 Sitta carolinensis 0.25774997
# Repeat the query for New Zealand: North Island, near Tāne Mahuta.
print('\n\n')
print('New Zealand')
interpreter.set_tensor(0, [np.float32(173.53)])  # longitude
interpreter.set_tensor(1, [np.float32(1.0)])     # first week of the year
interpreter.set_tensor(2, [np.float32(-35.59)])  # latitude
interpreter.invoke()
output_probabilities_ish = interpreter.get_tensor(88).squeeze()
top_ten = output_probabilities_ish.argsort()[::-1][:10]
for rank, species_idx in enumerate(top_ten):
    print(rank, output_label[species_idx], output_probabilities_ish[species_idx])
# New Zealand
# 0 Acridotheres tristis 0.6020632
# 1 Passer domesticus 0.47801244
# 2 Hirundo neoxena 0.47301206
# 3 Turdus merula 0.42792457
# 4 Chroicocephalus novaehollandiae 0.38255733
# 5 Todiramphus sanctus 0.35976177
# 6 Zosterops lateralis 0.3565269
# 7 Prosthemadera novaeseelandiae 0.35628808
# 8 Larus dominicanus 0.34348693
# 9 Gerygone igata 0.33886477
# Repeat the query for the Osa Peninsula, near the guest house.
print('\n\n')
print('Osa Peninsula')
interpreter.set_tensor(0, [np.float32(-83.34)])  # longitude
interpreter.set_tensor(1, [np.float32(1.0)])     # first week of the year
interpreter.set_tensor(2, [np.float32(8.40)])    # latitude
interpreter.invoke()
output_probabilities_ish = interpreter.get_tensor(88).squeeze()
top_ten = output_probabilities_ish.argsort()[::-1][:10]
for rank, species_idx in enumerate(top_ten):
    print(rank, output_label[species_idx], output_probabilities_ish[species_idx])
# Osa Peninsula
# 0 Ara macao 0.5610521
# 1 Ramphastos ambiguus 0.52797556
# 2 Ramphocelus passerinii 0.45787573
# 3 Tyrannus melancholicus 0.34361967
# 4 Poliocrania exsul 0.33964986
# 5 Cantorchilus semibadius 0.3127423
# 6 Thamnophilus bridgesi 0.30603442
# 7 Daptrius chimachima 0.30467013
# 8 Pitangus sulphuratus 0.3044303
# 9 Amazona autumnalis 0.29769742
# %%
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.patches import Polygon
import geopandas as gpd

# Map what the model thinks is the distribution of Ara macao (Scarlet Macaw)
# over Costa Rica, sampled on a 150x100 longitude/latitude grid.
long_range = [-86, -82.5]
lat_range = [7.5, 12]
plt.close('all')
long_points = np.linspace(*long_range, 150)
lat_points = np.linspace(*lat_range, 100)
long_grid, lat_grid = np.meshgrid(long_points, lat_points)
flat_long_grid = long_grid.flatten()
flat_lat_grid = lat_grid.flatten()

# The developers did not enable "batching", so every grid point needs its own
# invoke() call — this loop is slow by necessity.
# Hoist the species-index lookup (an O(n) list scan) out of the loop.
macaw_idx = output_label.index('Ara macao')
probs = list()
for c_long, c_lat in zip(flat_long_grid, flat_lat_grid):
    interpreter.set_tensor(0, [np.float32(c_long)])
    interpreter.set_tensor(1, [np.float32(1)])  # first week of the year
    interpreter.set_tensor(2, [np.float32(c_lat)])
    interpreter.invoke()
    probs.append(interpreter.get_tensor(88).squeeze()[macaw_idx])
prob_grid = np.reshape(probs, long_grid.shape)

# Country outline for context (GADM boundary file for Costa Rica).
boundaries = gpd.read_file('gadm41_CRI.gpkg')
all_coords = [np.asarray(list(x.exterior.coords)) for x in list(boundaries.geometry)[0].geoms]

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.contourf(long_grid, lat_grid, prob_grid, cmap='Greens')
ax.set_xlim(long_range)
ax.set_ylim(lat_range)
for cr_coords in all_coords:
    ax.add_patch(Polygon(cr_coords, alpha=0.5))
ax.set_aspect(1.0)
fig.savefig('output_macaw.png')
# %%
# Map what the model thinks is the distribution of Cardinalis cardinalis
# (Northern Cardinal) over the continental United States.
# NOTE: the original wrapped the patch loop in tqdm(), but tqdm is never
# imported anywhere in this file, which raises NameError; the plain
# comprehension below avoids that.
from matplotlib.collections import PatchCollection

long_range = [-130, -60]
lat_range = [22, 50]
plt.close('all')
long_points = np.linspace(*long_range, 150)
lat_points = np.linspace(*lat_range, 100)
long_grid, lat_grid = np.meshgrid(long_points, lat_points)
flat_long_grid = long_grid.flatten()
flat_lat_grid = lat_grid.flatten()

# The developers did not enable "batching", so every grid point needs its own
# invoke() call — this loop is slow by necessity.
# Hoist the species-index lookup (an O(n) list scan) out of the loop.
cardinal_idx = output_label.index('Cardinalis cardinalis')
probs = list()
for c_long, c_lat in zip(flat_long_grid, flat_lat_grid):
    interpreter.set_tensor(0, [np.float32(c_long)])
    interpreter.set_tensor(1, [np.float32(1)])  # first week of the year
    interpreter.set_tensor(2, [np.float32(c_lat)])
    interpreter.invoke()
    probs.append(interpreter.get_tensor(88).squeeze()[cardinal_idx])
prob_grid = np.reshape(probs, long_grid.shape)

# Country outline for context (GADM boundary file for the USA), drawn as
# un-filled black outlines so the contour plot shows through.
boundaries = gpd.read_file('gadm41_USA.gpkg')
all_coords = [np.asarray(list(x.exterior.coords)) for x in list(boundaries.geometry)[0].geoms]

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.contourf(long_grid, lat_grid, prob_grid, cmap='Greens')
ax.set_xlim(long_range)
ax.set_ylim(lat_range)
patches = [Polygon(cr_coords, alpha=0.5) for cr_coords in all_coords]
ax.add_collection(PatchCollection(patches, linewidth=0.5, facecolor='none', edgecolor='black'))
ax.set_aspect(1.0)
fig.savefig('output_cardinal.png')
def create_labels_species():
    """Regenerate labels_species.txt, mapping each model output index to a
    scientific species name.

    labels_2025.txt (one species code per line, acquired by unzipping the
    geo_v46.tflite file) is joined against the taxon table of the Merlin
    SQLite database ("merlin_room"); codes with no match become 'N/A'.
    """
    import sqlite3
    from contextlib import closing

    # Build the speciesCode -> scientificName map; closing() guarantees the
    # connection is released (the original left it open).
    with closing(sqlite3.connect("merlin_room")) as con:
        code_name_list = con.execute(
            "SELECT speciesCode, scientificName FROM taxon;"
        ).fetchall()
    code_name_map = {code: name for code, name in code_name_list}

    with open('labels_2025.txt', 'r') as lab:
        labels = lab.read().split('\n')

    species_index_label = [code_name_map.get(r, 'N/A') for r in labels]
    with open('labels_species.txt', 'w') as ff:
        ff.write('\n'.join(species_index_label))