# Exploration of Merlin Bird ID's geo_v46.tflite model:
# given (longitude, week_of_year, latitude) it scores ~2728 bird species.
import tensorflow as tf
|
|
import numpy as np
|
|
# Load the TFLite geo model. It maps (longitude, week_of_year, latitude)
# to a per-species score vector; see the printed details transcribed below.
interpreter = tf.lite.Interpreter(model_path="geo_v46.tflite")

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Tensors must be allocated before set_tensor/invoke can be used.
interpreter.allocate_tensors()

print(input_details)
# This prints the following. Output is trimmed to not overwhelm
# [{'name': 'serving_default_longitude:0',
#   'index': 0,
#   'shape': array([1], dtype=int32),
#   'dtype': numpy.float32 }
#
#  {'name': 'serving_default_week_of_year:0',
#   'index': 1,
#   'shape': array([1], dtype=int32),
#   'dtype': numpy.float32 }
#
#  {'name': 'serving_default_latitude:0',
#   'index': 2,
#   'shape': array([1], dtype=int32),
#   'dtype': numpy.float32 }

print(output_details)
# This prints the following. Output is trimmed to not overwhelm
# [{'dtype': <class 'numpy.float32'>,
#   'index': 88,
#   'shape': array([ 1, 2728], dtype=int32)}]
|
# %%
# Map model output indices to species names: line i of labels_species.txt
# is the scientific name for output index i (generated by
# create_labels_species at the bottom of this file).
with open('labels_species.txt', 'r') as ff:
    # splitlines() — unlike split('\n') — does not leave a spurious empty
    # trailing entry when the file ends with a newline; indices of real
    # entries are unchanged.
    output_label = ff.read().splitlines()
|
# This is for Ann Arbor
print('\n\n')
print('Ann Arbor')
# Look up tensor indices from the model metadata instead of trusting the
# hard-coded positions 0/1/2/88 — they would silently break if the model
# were re-exported with a different tensor layout.
input_index = {d['name']: d['index'] for d in input_details}
interpreter.set_tensor(input_index['serving_default_longitude:0'], [np.float32(-83.75)])
interpreter.set_tensor(input_index['serving_default_week_of_year:0'], [np.float32(1.0)])  # First week of the year
interpreter.set_tensor(input_index['serving_default_latitude:0'], [np.float32(42.29)])
interpreter.invoke()
# Per-species scores, shape (2728,) after squeeze (see output_details above).
output_probabilities_ish = interpreter.get_tensor(output_details[0]['index']).squeeze()
sorted_idx = output_probabilities_ish.argsort()[::-1]

# Print the ten highest-scoring species.
for rank, id_rank in enumerate(sorted_idx[:10]):
    print(rank, output_label[id_rank], output_probabilities_ish[id_rank])


# Ann Arbor
# 0 Junco hyemalis 0.32183477
# 1 Cardinalis cardinalis 0.31213352
# 2 Passer domesticus 0.28218403
# 3 Dryobates pubescens 0.27699217
# 4 Poecile atricapillus 0.2743776
# 5 Corvus brachyrhynchos 0.27222255
# 6 Zenaida macroura 0.26566097
# 7 Branta canadensis 0.2610584
# 8 Melanerpes carolinus 0.25785306
# 9 Sitta carolinensis 0.25774997
|
|
|
|
|
|
|
|
|
|
|
# Repeat for New Zealand in North Island near Tāne Mahuta
print('\n\n')
print('New Zealand')
# Resolve tensor indices by name rather than relying on hard-coded 0/1/2/88
# positions, which are an artifact of this particular export.
input_index = {d['name']: d['index'] for d in input_details}
interpreter.set_tensor(input_index['serving_default_longitude:0'], [np.float32(173.53)])
interpreter.set_tensor(input_index['serving_default_week_of_year:0'], [np.float32(1.0)])  # First week of the year
interpreter.set_tensor(input_index['serving_default_latitude:0'], [np.float32(-35.59)])
interpreter.invoke()
output_probabilities_ish = interpreter.get_tensor(output_details[0]['index']).squeeze()
sorted_idx = output_probabilities_ish.argsort()[::-1]

# Print the ten highest-scoring species.
for rank, id_rank in enumerate(sorted_idx[:10]):
    print(rank, output_label[id_rank], output_probabilities_ish[id_rank])


# New Zealand
# 0 Acridotheres tristis 0.6020632
# 1 Passer domesticus 0.47801244
# 2 Hirundo neoxena 0.47301206
# 3 Turdus merula 0.42792457
# 4 Chroicocephalus novaehollandiae 0.38255733
# 5 Todiramphus sanctus 0.35976177
# 6 Zosterops lateralis 0.3565269
# 7 Prosthemadera novaeseelandiae 0.35628808
# 8 Larus dominicanus 0.34348693
# 9 Gerygone igata 0.33886477
|
|
# Repeat for Osa peninsula near guest house
print('\n\n')
print('Osa Peninsula')
# Resolve tensor indices by name rather than relying on hard-coded 0/1/2/88
# positions, which are an artifact of this particular export.
input_index = {d['name']: d['index'] for d in input_details}
interpreter.set_tensor(input_index['serving_default_longitude:0'], [np.float32(-83.34)])
interpreter.set_tensor(input_index['serving_default_week_of_year:0'], [np.float32(1.0)])  # First week of the year
interpreter.set_tensor(input_index['serving_default_latitude:0'], [np.float32(8.40)])
interpreter.invoke()
output_probabilities_ish = interpreter.get_tensor(output_details[0]['index']).squeeze()
sorted_idx = output_probabilities_ish.argsort()[::-1]

# Print the ten highest-scoring species.
for rank, id_rank in enumerate(sorted_idx[:10]):
    print(rank, output_label[id_rank], output_probabilities_ish[id_rank])


# Osa Peninsula
# 0 Ara macao 0.5610521
# 1 Ramphastos ambiguus 0.52797556
# 2 Ramphocelus passerinii 0.45787573
# 3 Tyrannus melancholicus 0.34361967
# 4 Poliocrania exsul 0.33964986
# 5 Cantorchilus semibadius 0.3127423
# 6 Thamnophilus bridgesi 0.30603442
# 7 Daptrius chimachima 0.30467013
# 8 Pitangus sulphuratus 0.3044303
# 9 Amazona autumnalis 0.29769742
|
|
# %%
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.patches import Polygon
import geopandas as gpd

# This is a way to see what the model thinks is the distribution of Ara macao:
# evaluate the model on a longitude/latitude grid over Costa Rica and contour
# the score for that one species.
long_range = [-86, -82.5]
lat_range = [7.5, 12]
plt.close('all')
long_points = np.linspace(*long_range, 150)
lat_points = np.linspace(*lat_range, 100)

long_grid, lat_grid = np.meshgrid(long_points, lat_points)

flat_long_grid = long_grid.flatten()
flat_lat_grid = lat_grid.flatten()

# The developers did not enable "batching" so we have to loop through which is very slow
# Hoist loop invariants: the species index (an O(n) list search) and the
# tensor indices looked up from the model metadata instead of hard-coded ints.
macaw_index = output_label.index('Ara macao')
input_index = {d['name']: d['index'] for d in input_details}
longitude_idx = input_index['serving_default_longitude:0']
week_idx = input_index['serving_default_week_of_year:0']
latitude_idx = input_index['serving_default_latitude:0']
output_idx = output_details[0]['index']
probs = list()
for c_long, c_lat in zip(flat_long_grid, flat_lat_grid):
    interpreter.set_tensor(longitude_idx, [np.float32(c_long)])
    interpreter.set_tensor(week_idx, [np.float32(1)])
    interpreter.set_tensor(latitude_idx, [np.float32(c_lat)])
    interpreter.invoke()
    probs.append(interpreter.get_tensor(output_idx).squeeze()[macaw_index])

prob_grid = np.reshape(probs, long_grid.shape)

# Costa Rica outline from GADM; geometry 0 is a multi-part polygon whose
# parts we draw individually.  (An unused convex-hull computation was removed.)
boundaries = gpd.read_file('gadm41_CRI.gpkg')
all_coords = [np.asarray(list(x.exterior.coords)) for x in list(boundaries.geometry)[0].geoms]

X = long_grid
Y = lat_grid
Z = prob_grid
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.contourf(X, Y, Z, cmap='Greens')
ax.set_xlim(long_range)
ax.set_ylim(lat_range)
for cr_coords in all_coords:
    ax.add_patch(Polygon(cr_coords, alpha=0.5))
ax.set_aspect(1.0)
fig.savefig('output_macaw.png')
|
# %%

# This is a way to see what the model thinks is the distribution of
# Cardinalis cardinalis across the continental US (the original comment
# wrongly said "Ara Macao" — copy-paste from the section above).
long_range = [-130, -60]
lat_range = [22, 50]
plt.close('all')
long_points = np.linspace(*long_range, 150)
lat_points = np.linspace(*lat_range, 100)

long_grid, lat_grid = np.meshgrid(long_points, lat_points)

flat_long_grid = long_grid.flatten()
flat_lat_grid = lat_grid.flatten()

# The developers did not enable "batching" so we have to loop through which is very slow
# Hoist loop invariants: the species index (an O(n) list search) and the
# tensor indices looked up from the model metadata instead of hard-coded ints.
cardinal_index = output_label.index('Cardinalis cardinalis')
input_index = {d['name']: d['index'] for d in input_details}
longitude_idx = input_index['serving_default_longitude:0']
week_idx = input_index['serving_default_week_of_year:0']
latitude_idx = input_index['serving_default_latitude:0']
output_idx = output_details[0]['index']
probs = list()
for c_long, c_lat in zip(flat_long_grid, flat_lat_grid):
    interpreter.set_tensor(longitude_idx, [np.float32(c_long)])
    interpreter.set_tensor(week_idx, [np.float32(1)])
    interpreter.set_tensor(latitude_idx, [np.float32(c_lat)])
    interpreter.invoke()
    probs.append(interpreter.get_tensor(output_idx).squeeze()[cardinal_index])

prob_grid = np.reshape(probs, long_grid.shape)

# US outline from GADM; an unused convex-hull computation was removed.
boundaries = gpd.read_file('gadm41_USA.gpkg')
all_coords = [np.asarray(list(x.exterior.coords)) for x in list(boundaries.geometry)[0].geoms]

X = long_grid
Y = lat_grid
Z = prob_grid
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.contourf(X, Y, Z, cmap='Greens')
ax.set_xlim(long_range)
ax.set_ylim(lat_range)

from matplotlib.collections import PatchCollection
# NOTE: the original wrapped this loop in tqdm(), which was never imported
# and raised NameError.  Plain iteration is fine — this loop is cheap
# compared with the per-point model evaluation above.
patches = [Polygon(cr_coords, alpha=0.5) for cr_coords in all_coords]
ax.add_collection(PatchCollection(patches, linewidth=0.5, facecolor='none', edgecolor='black'))

ax.set_aspect(1.0)
fig.savefig('output_cardinal.png')
|
def create_labels_species():
    """Regenerate labels_species.txt, mapping model output indices to names.

    labels_2025.txt — one Merlin species code per line, acquired by unzipping
    the geo_v46.tflite file — gives the species code for each output index.
    The Merlin sqlite database ("merlin_room") translates each code to a
    scientific name; codes with no database entry become 'N/A'.

    Writes labels_species.txt with one scientific name per line, line i
    corresponding to model output index i.
    """
    import sqlite3
    from contextlib import closing

    # closing() guarantees the connection is released even if the query
    # fails (the original leaked the connection).
    with closing(sqlite3.connect("merlin_room")) as con:
        code_name_list = con.execute(
            "SELECT speciesCode, scientificName FROM taxon;"
        ).fetchall()
    code_name_map = {code: name for code, name in code_name_list}

    with open('labels_2025.txt', 'r') as lab:
        labels = lab.read().split('\n')
    species_index_label = [code_name_map.get(r, 'N/A') for r in labels]
    with open('labels_species.txt', 'w') as ff:
        ff.write('\n'.join(species_index_label))