From f46d1938260683e6d1b7661fa4a3597c136af971 Mon Sep 17 00:00:00 2001
From: "Ishan S. Patel"
Date: Thu, 1 Jul 2021 20:26:24 -0400
Subject: [PATCH] yacwc

---
 __pycache__/bear_utils.cpython-39.pyc |  Bin 0 -> 784 bytes
 __pycache__/data.cpython-39.pyc       |  Bin 5549 -> 3910 bytes
 __pycache__/model.cpython-39.pyc      |  Bin 0 -> 579 bytes
 __pycache__/transforms.cpython-39.pyc |  Bin 8045 -> 8155 bytes
 __pycache__/utils.cpython-39.pyc      |  Bin 9178 -> 9204 bytes
 bear_utils.py                         |   24 +++
 config.py                             |    0
 data.py                               |  125 +++++++++++++++++++------
 model.py                              |   11 ++-
 models/20210701_201220.json           |   22 +++++
 models/20210701_201841.json           |   23 +++++
 models/20210701_201853.json           |   23 +++++
 models/20210701_202607.json           |   23 +++++
 train.py                              |   88 +++++++++++------
 transforms.py                         |  110 ++++++++++++++++------
 utils.py                              |  130 +++++++++++++++-----------
 16 files changed, 433 insertions(+), 146 deletions(-)
 create mode 100644 __pycache__/bear_utils.cpython-39.pyc
 create mode 100644 __pycache__/model.cpython-39.pyc
 create mode 100644 bear_utils.py
 create mode 100644 config.py
 create mode 100644 models/20210701_201220.json
 create mode 100644 models/20210701_201841.json
 create mode 100644 models/20210701_201853.json
 create mode 100644 models/20210701_202607.json

diff --git a/__pycache__/bear_utils.cpython-39.pyc b/__pycache__/bear_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2a4c4f3c2d2650aa15da4aaae49721af9ccbd512
GIT binary patch
[base64-encoded binary deltas for the five committed __pycache__/*.pyc files
omitted: compiled CPython 3.9 bytecode, not human-readable]
zyoe3tjwzG97ayk(ZZ?fq{WxqQkv}bIKd}q?j3nCTlVuW|Z3emU%uSqx|Gv*5iyilSA3o zGa622Wv^j0o7}`+#%MM9CwnfV-Q;wR1B_0SH8^)O@@{^@xr~v~XLC3AWG2Sq$?|+z zj8T)j_;eYgChy?WWQ>{oh)Ox@-Z+lR80QMe~i&@@(}@3 z#+u2$1U4}GOWWwgdm8}XJBBs#avuc zROAn0v1H|E=4rANfqWJ<`LK|dU=PSB4v;{bEs6)362+O4 zSds{378fKYr^ct{mPE1V-`PvM_Qma)859hLMYrjfshoi;)Qie=_}N<3SPo z#q^(r3#3Pgk&RJ^QI1iFQG}6)k%^Iok%^HFjM*637f2PUjBxD$UC%=|Zl8OX{h5?8G1%6Q+h*bh2ct8n5)`x+Cp&aCRPy*m!3lRM;enFJXp zZxK;2NDUH%&MTouLx2O2?#dEqM4h`6pR=dqb6r6w#&AFw6Leb>K z2;{mVYml%vhyW#uB2cO*0+kI#-5@cLE%SUvM%l@|tj8I(CWo@E zXVjm}%3i~0GP#MpjL~B9Pxf3!o5|@M2N)eDYjEynU#5rQB>pMinl7ISe)QBfd> z#gdhunWxE81oG9b$%losL_0u6@qj$Z$igVZ$im3?pN)C)KcNV5uu0w^LzEa87&Mu{ znu-!Yj+@*d%)&Tfa;tET0Z5f5Gej;Pq&NXYfD9}G6TBd6ITKS-;&W3=Qj4Z+Ruoys z#27I7f|w+CDo82FfTFa?U&Qpd(?EO{5D_z3S6q<|p0`8wTV&QHp^#hqG`7+;*4oS&Cce2WXr zDJja#OTWccl3$XT12JK;zJw;*bWk8}iIT(c) z*%*Zw4 zr>7R_fn1temReM)$ynqM3R^}$O|BwCkjEzrN{Wa^gIES20u=H^i6BsB!WG$;phNlS`zm8Lv#8oG#NJ zhRNsUN<|kkFfcGNGB6agfpjslF|sfgEu4H`R(5ikJU_<*a8MR4oV-y^eDV@`L0%4! zQfOE#ntVv!Sz-~$-I5@eu@onkrD`%2%>apnO_o(q;^zS=2L}No8)MOu&3+0-jEuJ? zH!8MU_kpyqr{1re4Yapv@b(jt(dMWP@sPfBW8W^!tLa(-!ENm0&Z zXC-~cTa$~GlsV=>95`e0JS9ywL6B{Flg}#2C$cee{AXfh`p?G1$H?)Yjfsn~Xd75* zh^Evn?&SQOoYdr!%>2CKB9Lc_m_asYfryD90_3(LaMl62;g(=Yeo=CUUP@+hNl|7} zX-R5I(Z 0.1: + if name["supercategory"] == "Aves": + print(len(v), to_common([name["name"]])) + + # %% + + fc = sorted( + category_distances, key=lambda x: len(category_distances[x]), reverse=True + ) + for x in fc: + cc = train_dataset.orig_id_to_name[x] + if cc["supercategory"] == "Aves": + ou = to_common([cc["name"]]) + print(ou, len(category_distances[x])) + + +# %% diff --git a/model.py b/model.py index d5e71ba..47a77f0 100644 --- a/model.py +++ b/model.py @@ -1,11 +1,12 @@ # %% -from torchvision.models.detection import fasterrcnn_resnet50_fpn +import torchvision.models.detection + from torchvision.models.detection.faster_rcnn import FastRCNNPredictor -def Model(num_classes): - model = fasterrcnn_resnet50_fpn(pretrained=True) - num_classes = 2 # 1 class (person) + background + +def Model(num_classes, model_type=None): + chosen_model = torchvision.models.detection.__dict__[model_type] + model = chosen_model(pretrained=True) in_features = model.roi_heads.box_predictor.cls_score.in_features model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) return model - diff --git a/models/20210701_201220.json b/models/20210701_201220.json new file mode 100644 index 0000000..a5fe817 --- /dev/null +++ b/models/20210701_201220.json @@ -0,0 +1,22 @@ +{ + "categories": [ + { + "supercategory": "Aves", + "id": 206, + "name": "Archilochus colubris", + "new_id": 1 + }, + { + "supercategory": "Aves", + "id": 4493, + "name": "Icterus galbula", + "new_id": 2 + }, + { + "supercategory": "Aves", + "id": 403, + "name": "Poecile atricapillus", + "new_id": 3 + } + ] +} \ No newline at end of file diff --git a/models/20210701_201841.json b/models/20210701_201841.json new file mode 100644 index 0000000..a616beb --- /dev/null +++ b/models/20210701_201841.json @@ -0,0 +1,23 @@ +{ + "categories": [ + { + "supercategory": "Aves", + "id": 206, + "name": "Archilochus colubris", + "new_id": 1 + }, + { + "supercategory": "Aves", + "id": 4493, + "name": "Icterus galbula", + "new_id": 2 + }, + { + "supercategory": "Aves", + "id": 403, + "name": "Poecile atricapillus", + "new_id": 3 + } + ], + "model_type": "fasterrcnn_mobilenet_v3_large_fpn" +} \ No newline at end of file diff --git a/models/20210701_201853.json b/models/20210701_201853.json new file mode 100644 index 
diff --git a/models/20210701_201853.json b/models/20210701_201853.json
new file mode 100644
index 0000000..a616beb
--- /dev/null
+++ b/models/20210701_201853.json
@@ -0,0 +1,23 @@
+{
+    "categories": [
+        {
+            "supercategory": "Aves",
+            "id": 206,
+            "name": "Archilochus colubris",
+            "new_id": 1
+        },
+        {
+            "supercategory": "Aves",
+            "id": 4493,
+            "name": "Icterus galbula",
+            "new_id": 2
+        },
+        {
+            "supercategory": "Aves",
+            "id": 403,
+            "name": "Poecile atricapillus",
+            "new_id": 3
+        }
+    ],
+    "model_type": "fasterrcnn_mobilenet_v3_large_fpn"
+}
\ No newline at end of file
diff --git a/models/20210701_202607.json b/models/20210701_202607.json
new file mode 100644
index 0000000..a616beb
--- /dev/null
+++ b/models/20210701_202607.json
@@ -0,0 +1,23 @@
+{
+    "categories": [
+        {
+            "supercategory": "Aves",
+            "id": 206,
+            "name": "Archilochus colubris",
+            "new_id": 1
+        },
+        {
+            "supercategory": "Aves",
+            "id": 4493,
+            "name": "Icterus galbula",
+            "new_id": 2
+        },
+        {
+            "supercategory": "Aves",
+            "id": 403,
+            "name": "Poecile atricapillus",
+            "new_id": 3
+        }
+    ],
+    "model_type": "fasterrcnn_mobilenet_v3_large_fpn"
+}
\ No newline at end of file
diff --git a/train.py b/train.py
index 981dd18..aa0d344 100644
--- a/train.py
+++ b/train.py
@@ -4,49 +4,79 @@
 from model import Model
 from data import iNaturalistDataset
 import torch
 import os
-import time
+import datetime as dt
+import json
+import utils
+
+if not os.path.exists("models/"):
+    os.mkdir("models")
+
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+else:
+    device = torch.device("cpu")
+
+model_root = "models/" + dt.datetime.now().strftime("%Y%m%d_%H%M%S")
+model_path = model_root + ".pth"
+model_info = model_root + ".json"
+
+
+species_list = set(["Poecile atricapillus", "Archilochus colubris", "Icterus galbula"])
+model_type = "fasterrcnn_mobilenet_v3_large_fpn"
 
-if not os.path.exists('models/'):
-    os.mkdirs('models')
-device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
 
 
 def run():
-    val_dataset = iNaturalistDataset(validation=True, transforms = get_transform(train=True))
-    train_dataset = iNaturalistDataset(train=True, transforms = get_transform(train=False))
+    val_dataset = iNaturalistDataset(
+        validation=True,
+        species=species_list,
+    )
+    train_dataset = iNaturalistDataset(
+        train=True,
+        species=species_list,
+    )
+    with open(model_info, "w") as js_p:
+        json.dump(
+            {"categories": train_dataset.categories, "model_type": model_type},
+            js_p,
+            default=str,
+            indent=4,
+        )
 
     train_data_loader = torch.utils.data.DataLoader(
-        train_dataset, batch_size=8, shuffle=True, num_workers=1, collate_fn=utils.collate_fn
-    )
-    val_data_loader = torch.utils.data.DataLoader(
-        val_dataset, batch_size=8, shuffle=True, num_workers=1, collate_fn=utils.collate_fn
+        train_dataset,
+        batch_size=8,
+        shuffle=True,
+        num_workers=4,
+        collate_fn=utils.collate_fn,
     )
 
-    num_classes = 5
-    model = Model(num_classes)
+    val_data_loader = torch.utils.data.DataLoader(
+        val_dataset,
+        batch_size=8,
+        shuffle=True,
+        num_workers=4,
+        collate_fn=utils.collate_fn,
+    )
+
+    num_classes = len(species_list) + 1
+    model = Model(num_classes, model_type)
     model.to(device)
 
     params = [p for p in model.parameters() if p.requires_grad]
-    optimizer = torch.optim.SGD(params, lr=0.005,
-                                momentum=0.9, weight_decay=0.0005)
+    optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
 
-    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
-                                                   step_size=3,
-                                                   gamma=0.1)
+    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
 
     num_epochs = 10
 
     for epoch in range(num_epochs):
-        print(epoch)
-        torch.save(model.state_dict(), 'model_weights_start_'+str(epoch)+ '.pth')
-        # train for one epoch, printing every 10 iterations
-        engine.train_one_epoch(model, optimizer, train_data_loader, device, epoch, print_freq=10)
-        torch.save(model.state_dict(), 'model_weights_post_train_'+str(epoch)+ '.pth')
-        # update the learning rate
+        train_one_epoch(
+            model, optimizer, train_data_loader, device, epoch, print_freq=10
+        )
         lr_scheduler.step()
-        torch.save(model.state_dict(), 'model_weights_post_step_'+str(epoch)+ '.pth')
-        # evaluate on the test dataset
-        engine.evaluate(model, val_data_loader, device=device)
-
-
-if __name__ == "__main__":
-    run()
\ No newline at end of file
+
+
+if __name__ == "__main__":
+    run()
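Each run now leaves a models/<timestamp>.json / .pth pair. A hedged sketch of
reloading one for inference, assuming only the JSON schema shown above; the
timestamp, image path, and 0.5 score cutoff are made up for illustration:

    import json

    import torch
    import torchvision.transforms.functional as TF
    from PIL import Image

    from model import Model

    stamp = "models/20210701_202607"  # any run's timestamp
    with open(stamp + ".json") as f:
        info = json.load(f)

    # One class per saved category plus the implicit background class.
    model = Model(len(info["categories"]) + 1, info["model_type"])
    model.load_state_dict(torch.load(stamp + ".pth", map_location="cpu"))
    model.eval()

    img = TF.to_tensor(Image.open("hummingbird.jpg").convert("RGB"))
    with torch.no_grad():
        (pred,) = model([img])  # detection models take a list of CHW tensors
    names = {c["new_id"]: c["name"] for c in info["categories"]}
    for box, label, score in zip(pred["boxes"], pred["labels"], pred["scores"]):
        if score > 0.5:
            print(names[int(label)], box.tolist())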
diff --git a/transforms.py b/transforms.py
index 8e4b887..a8a2886 100644
--- a/transforms.py
+++ b/transforms.py
@@ -28,8 +28,9 @@ class Compose(object):
 
 
 class RandomHorizontalFlip(T.RandomHorizontalFlip):
-    def forward(self, image: Tensor,
-                target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+    def forward(
+        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
+    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
         if torch.rand(1) < self.p:
             image = F.hflip(image)
             if target is not None:
@@ -45,15 +46,23 @@ class RandomHorizontalFlip(T.RandomHorizontalFlip):
 
 
 class ToTensor(nn.Module):
-    def forward(self, image: Tensor,
-                target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+    def forward(
+        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
+    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
         image = F.to_tensor(image)
         return image, target
 
 
 class RandomIoUCrop(nn.Module):
-    def __init__(self, min_scale: float = 0.3, max_scale: float = 1.0, min_aspect_ratio: float = 0.5,
-                 max_aspect_ratio: float = 2.0, sampler_options: Optional[List[float]] = None, trials: int = 40):
+    def __init__(
+        self,
+        min_scale: float = 0.3,
+        max_scale: float = 1.0,
+        min_aspect_ratio: float = 0.5,
+        max_aspect_ratio: float = 2.0,
+        sampler_options: Optional[List[float]] = None,
+        trials: int = 40,
+    ):
         super().__init__()
         # Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174
         self.min_scale = min_scale
@@ -65,14 +74,19 @@ class RandomIoUCrop(nn.Module):
         self.options = sampler_options
         self.trials = trials
 
-    def forward(self, image: Tensor,
-                target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+    def forward(
+        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
+    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
         if target is None:
             raise ValueError("The targets can't be None for this transform.")
 
         if isinstance(image, torch.Tensor):
             if image.ndimension() not in {2, 3}:
-                raise ValueError('image should be 2/3 dimensional. Got {} dimensions.'.format(image.ndimension()))
+                raise ValueError(
+                    "image should be 2/3 dimensional. Got {} dimensions.".format(
+                        image.ndimension()
+                    )
+                )
             elif image.ndimension() == 2:
                 image = image.unsqueeze(0)
 
@@ -82,7 +96,9 @@ class RandomIoUCrop(nn.Module):
             # sample an option
            idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
             min_jaccard_overlap = self.options[idx]
-            if min_jaccard_overlap >= 1.0:  # a value larger than 1 encodes the leave as-is option
+            if (
+                min_jaccard_overlap >= 1.0
+            ):  # a value larger than 1 encodes the leave as-is option
                 return image, target
 
             for _ in range(self.trials):
@@ -106,14 +122,22 @@ class RandomIoUCrop(nn.Module):
                 # check for any valid boxes with centers within the crop area
                 cx = 0.5 * (target["boxes"][:, 0] + target["boxes"][:, 2])
                 cy = 0.5 * (target["boxes"][:, 1] + target["boxes"][:, 3])
-                is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
+                is_within_crop_area = (
+                    (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
+                )
                 if not is_within_crop_area.any():
                     continue
 
                 # check at least 1 box with jaccard limitations
                 boxes = target["boxes"][is_within_crop_area]
-                ious = torchvision.ops.boxes.box_iou(boxes, torch.tensor([[left, top, right, bottom]],
-                                                     dtype=boxes.dtype, device=boxes.device))
+                ious = torchvision.ops.boxes.box_iou(
+                    boxes,
+                    torch.tensor(
+                        [[left, top, right, bottom]],
+                        dtype=boxes.dtype,
+                        device=boxes.device,
+                    ),
+                )
                 if ious.max() < min_jaccard_overlap:
                     continue
 
@@ -130,14 +154,21 @@ class RandomIoUCrop(nn.Module):
 
 
 class RandomZoomOut(nn.Module):
-    def __init__(self, fill: Optional[List[float]] = None, side_range: Tuple[float, float] = (1., 4.), p: float = 0.5):
+    def __init__(
+        self,
+        fill: Optional[List[float]] = None,
+        side_range: Tuple[float, float] = (1.0, 4.0),
+        p: float = 0.5,
+    ):
         super().__init__()
         if fill is None:
-            fill = [0., 0., 0.]
+            fill = [0.0, 0.0, 0.0]
         self.fill = fill
         self.side_range = side_range
-        if side_range[0] < 1. or side_range[0] > side_range[1]:
-            raise ValueError("Invalid canvas side range provided {}.".format(side_range))
+        if side_range[0] < 1.0 or side_range[0] > side_range[1]:
+            raise ValueError(
+                "Invalid canvas side range provided {}.".format(side_range)
+            )
         self.p = p
 
     @torch.jit.unused
@@ -146,11 +177,16 @@ class RandomZoomOut(nn.Module):
     def _get_fill_value(self, is_pil):
         # We fake the type to make it work on JIT
         return tuple(int(x) for x in self.fill) if is_pil else 0
 
-    def forward(self, image: Tensor,
-                target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+    def forward(
+        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
+    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
         if isinstance(image, torch.Tensor):
             if image.ndimension() not in {2, 3}:
-                raise ValueError('image should be 2/3 dimensional. Got {} dimensions.'.format(image.ndimension()))
+                raise ValueError(
+                    "image should be 2/3 dimensional. Got {} dimensions.".format(
+                        image.ndimension()
+                    )
+                )
             elif image.ndimension() == 2:
                 image = image.unsqueeze(0)
@@ -159,7 +195,9 @@ class RandomZoomOut(nn.Module):
 
         orig_w, orig_h = F._get_image_size(image)
 
-        r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
+        r = self.side_range[0] + torch.rand(1) * (
+            self.side_range[1] - self.side_range[0]
+        )
         canvas_width = int(orig_w * r)
         canvas_height = int(orig_h * r)
@@ -176,9 +214,12 @@ class RandomZoomOut(nn.Module):
 
         image = F.pad(image, [left, top, right, bottom], fill=fill)
         if isinstance(image, torch.Tensor):
-            v = torch.tensor(self.fill, device=image.device, dtype=image.dtype).view(-1, 1, 1)
-            image[..., :top, :] = image[..., :, :left] = image[..., (top + orig_h):, :] = \
-                image[..., :, (left + orig_w):] = v
+            v = torch.tensor(self.fill, device=image.device, dtype=image.dtype).view(
+                -1, 1, 1
+            )
+            image[..., :top, :] = image[..., :, :left] = image[
+                ..., (top + orig_h) :, :
+            ] = image[..., :, (left + orig_w) :] = v
 
         if target is not None:
             target["boxes"][:, 0::2] += left
@@ -188,8 +229,14 @@ class RandomZoomOut(nn.Module):
 
 
 class RandomPhotometricDistort(nn.Module):
-    def __init__(self, contrast: Tuple[float] = (0.5, 1.5), saturation: Tuple[float] = (0.5, 1.5),
-                 hue: Tuple[float] = (-0.05, 0.05), brightness: Tuple[float] = (0.875, 1.125), p: float = 0.5):
+    def __init__(
+        self,
+        contrast: Tuple[float] = (0.5, 1.5),
+        saturation: Tuple[float] = (0.5, 1.5),
+        hue: Tuple[float] = (-0.05, 0.05),
+        brightness: Tuple[float] = (0.875, 1.125),
+        p: float = 0.5,
+    ):
         super().__init__()
         self._brightness = T.ColorJitter(brightness=brightness)
         self._contrast = T.ColorJitter(contrast=contrast)
@@ -197,11 +244,16 @@ class RandomPhotometricDistort(nn.Module):
         self._saturation = T.ColorJitter(saturation=saturation)
         self.p = p
 
-    def forward(self, image: Tensor,
-                target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+    def forward(
+        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
+    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
         if isinstance(image, torch.Tensor):
             if image.ndimension() not in {2, 3}:
-                raise ValueError('image should be 2/3 dimensional. Got {} dimensions.'.format(image.ndimension()))
+                raise ValueError(
+                    "image should be 2/3 dimensional. Got {} dimensions.".format(
+                        image.ndimension()
+                    )
+                )
             elif image.ndimension() == 2:
                 image = image.unsqueeze(0)
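These transforms thread the (image, target) pair through every forward so the
boxes stay aligned with the pixel operations. train.py's removed lines
referenced a get_transform helper that never appears in this patch; a plausible
sketch, using only the classes defined in this file:

    import transforms as T

    def get_transform(train):
        # ToTensor first; the remaining transforms expect a CHW tensor image.
        ts = [T.ToTensor()]
        if train:
            # Box-aware flip: target["boxes"] is mirrored along with the image.
            ts.append(T.RandomHorizontalFlip(0.5))
        return T.Compose(ts)

    # image is a PIL image; target holds "boxes" (N x 4) and "labels" (N), e.g.:
    # image, target = get_transform(train=True)(image, target)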
diff --git a/utils.py b/utils.py
index 3c52abb..88baf61 100644
--- a/utils.py
+++ b/utils.py
@@ -8,6 +8,8 @@
 import torch
 import torch.distributed as dist
 
+
+
 class SmoothedValue(object):
     """Track a series of values and provide access to smoothed values over a
     window or the global series average.
@@ -32,7 +34,7 @@ class SmoothedValue(object):
         """
         if not is_dist_avail_and_initialized():
             return
-        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+        t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
         dist.barrier()
         dist.all_reduce(t)
         t = t.tolist()
@@ -67,7 +69,8 @@ class SmoothedValue(object):
             avg=self.avg,
             global_avg=self.global_avg,
             max=self.max,
-            value=self.value)
+            value=self.value,
+        )
 
 
 def all_gather(data):
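SmoothedValue, like the rest of this file, comes from the torchvision
detection references: it keeps a bounded window of recent values plus global
totals. A usage sketch, assuming the reference constructor signature
(window_size, fmt); the window size and format string are illustrative:

    from utils import SmoothedValue

    loss_meter = SmoothedValue(window_size=20, fmt="{median:.4f} ({global_avg:.4f})")
    for step in range(100):
        loss_meter.update(1.0 / (step + 1))
    print(loss_meter)  # median of the last 20 updates, then the global average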
@@ -130,15 +133,14 @@ class MetricLogger(object):
             return self.meters[attr]
         if attr in self.__dict__:
             return self.__dict__[attr]
-        raise AttributeError("'{}' object has no attribute '{}'".format(
-            type(self).__name__, attr))
+        raise AttributeError(
+            "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
+        )
 
     def __str__(self):
         loss_str = []
         for name, meter in self.meters.items():
-            loss_str.append(
-                "{}: {}".format(name, str(meter))
-            )
+            loss_str.append("{}: {}".format(name, str(meter)))
         return self.delimiter.join(loss_str)
 
     def synchronize_between_processes(self):
@@ -151,31 +153,35 @@
     def log_every(self, iterable, print_freq, header=None):
         i = 0
         if not header:
-            header = ''
+            header = ""
         start_time = time.time()
         end = time.time()
-        iter_time = SmoothedValue(fmt='{avg:.4f}')
-        data_time = SmoothedValue(fmt='{avg:.4f}')
-        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+        iter_time = SmoothedValue(fmt="{avg:.4f}")
+        data_time = SmoothedValue(fmt="{avg:.4f}")
+        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
         if torch.cuda.is_available():
-            log_msg = self.delimiter.join([
-                header,
-                '[{0' + space_fmt + '}/{1}]',
-                'eta: {eta}',
-                '{meters}',
-                'time: {time}',
-                'data: {data}',
-                'max mem: {memory:.0f}'
-            ])
+            log_msg = self.delimiter.join(
+                [
+                    header,
+                    "[{0" + space_fmt + "}/{1}]",
+                    "eta: {eta}",
+                    "{meters}",
+                    "time: {time}",
+                    "data: {data}",
+                    "max mem: {memory:.0f}",
+                ]
+            )
         else:
-            log_msg = self.delimiter.join([
-                header,
-                '[{0' + space_fmt + '}/{1}]',
-                'eta: {eta}',
-                '{meters}',
-                'time: {time}',
-                'data: {data}'
-            ])
+            log_msg = self.delimiter.join(
+                [
+                    header,
+                    "[{0" + space_fmt + "}/{1}]",
+                    "eta: {eta}",
+                    "{meters}",
+                    "time: {time}",
+                    "data: {data}",
+                ]
+            )
         MB = 1024.0 * 1024.0
         for obj in iterable:
             data_time.update(time.time() - end)
@@ -185,22 +191,37 @@
                 eta_seconds = iter_time.global_avg * (len(iterable) - i)
                 eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                 if torch.cuda.is_available():
-                    print(log_msg.format(
-                        i, len(iterable), eta=eta_string,
-                        meters=str(self),
-                        time=str(iter_time), data=str(data_time),
-                        memory=torch.cuda.max_memory_allocated() / MB))
+                    print(
+                        log_msg.format(
+                            i,
+                            len(iterable),
+                            eta=eta_string,
+                            meters=str(self),
+                            time=str(iter_time),
+                            data=str(data_time),
+                            memory=torch.cuda.max_memory_allocated() / MB,
+                        )
+                    )
                 else:
-                    print(log_msg.format(
-                        i, len(iterable), eta=eta_string,
-                        meters=str(self),
-                        time=str(iter_time), data=str(data_time)))
+                    print(
+                        log_msg.format(
+                            i,
+                            len(iterable),
+                            eta=eta_string,
+                            meters=str(self),
+                            time=str(iter_time),
+                            data=str(data_time),
+                        )
+                    )
             i += 1
             end = time.time()
         total_time = time.time() - start_time
         total_time_str = str(datetime.timedelta(seconds=int(total_time)))
-        print('{} Total time: {} ({:.4f} s / it)'.format(
-            header, total_time_str, total_time / len(iterable)))
+        print(
+            "{} Total time: {} ({:.4f} s / it)".format(
+                header, total_time_str, total_time / len(iterable)
+            )
+        )
 
 
 def collate_fn(batch):
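MetricLogger strings named SmoothedValue meters together and drives the
per-iteration progress lines that train_one_epoch prints. A hedged usage
sketch; the delimiter, meter name, and print frequency are illustrative:

    from utils import MetricLogger

    logger = MetricLogger(delimiter="  ")
    for batch in logger.log_every(range(50), print_freq=10, header="Epoch: [0]"):
        loss = 0.123  # stand-in for a real training step
        logger.update(loss=loss)  # each keyword becomes a SmoothedValue meter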
@@ -208,7 +229,6 @@ def collate_fn(batch):
 
 
 def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor):
-
     def f(x):
         if x >= warmup_iters:
             return 1
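collate_fn itself is untouched by this patch (only a nearby blank line
changes), so the diff never shows its body. In the torchvision detection
references it simply transposes the batch, since images and targets vary in
shape; roughly:

    def collate_fn(batch):
        # [(img1, tgt1), (img2, tgt2)] -> ((img1, img2), (tgt1, tgt2))
        return tuple(zip(*batch))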
@@ -231,10 +251,11 @@ def setup_for_distributed(is_master):
     This function disables printing when not in master process
     """
     import builtins as __builtin__
+
     builtin_print = __builtin__.print
 
     def print(*args, **kwargs):
-        force = kwargs.pop('force', False)
+        force = kwargs.pop("force", False)
         if is_master or force:
             builtin_print(*args, **kwargs)
@@ -271,25 +292,30 @@ def save_on_master(*args, **kwargs):
 
 
 def init_distributed_mode(args):
-    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
         args.rank = int(os.environ["RANK"])
-        args.world_size = int(os.environ['WORLD_SIZE'])
-        args.gpu = int(os.environ['LOCAL_RANK'])
-    elif 'SLURM_PROCID' in os.environ:
-        args.rank = int(os.environ['SLURM_PROCID'])
+        args.world_size = int(os.environ["WORLD_SIZE"])
+        args.gpu = int(os.environ["LOCAL_RANK"])
+    elif "SLURM_PROCID" in os.environ:
+        args.rank = int(os.environ["SLURM_PROCID"])
         args.gpu = args.rank % torch.cuda.device_count()
     else:
-        print('Not using distributed mode')
+        print("Not using distributed mode")
         args.distributed = False
         return
 
     args.distributed = True
 
     torch.cuda.set_device(args.gpu)
-    args.dist_backend = 'nccl'
-    print('| distributed init (rank {}): {}'.format(
-        args.rank, args.dist_url), flush=True)
-    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
-                                         world_size=args.world_size, rank=args.rank)
+    args.dist_backend = "nccl"
+    print(
+        "| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True
+    )
+    torch.distributed.init_process_group(
+        backend=args.dist_backend,
+        init_method=args.dist_url,
+        world_size=args.world_size,
+        rank=args.rank,
+    )
     torch.distributed.barrier()
     setup_for_distributed(args.rank == 0)
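init_distributed_mode expects an argparse-style namespace with a dist_url
attribute, reads RANK/WORLD_SIZE (torchrun) or SLURM_PROCID from the
environment, and otherwise just flags args.distributed = False. A sketch of
the expected call shape; the --dist-url flag is the caller's convention, not
defined in this file:

    import argparse

    from utils import init_distributed_mode

    parser = argparse.ArgumentParser()
    parser.add_argument("--dist-url", default="env://")
    args = parser.parse_args()
    init_distributed_mode(args)  # sets args.rank / args.gpu / args.distributed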