yacwc
This commit is contained in:
3
data.py
3
data.py
@@ -15,7 +15,8 @@ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cp
|
||||
if sys.platform == "win32":
|
||||
PATH_ROOT = r"D:\ishan\ml\inaturalist\\"
|
||||
else:
|
||||
raise NotImplementedError("Not defined for this platform")
|
||||
PATH_ROOT = '/home/thebears/data/ml/inaturalist'
|
||||
|
||||
|
||||
|
||||
def get_transform(train):
|
||||
|
||||
4
train.py
4
train.py
@@ -45,7 +45,7 @@ def run():
|
||||
|
||||
train_data_loader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=8,
|
||||
batch_size=16,
|
||||
shuffle=True,
|
||||
num_workers=4,
|
||||
collate_fn=utils.collate_fn,
|
||||
@@ -53,7 +53,7 @@ def run():
|
||||
|
||||
val_data_loader = torch.utils.data.DataLoader(
|
||||
val_dataset,
|
||||
batch_size=8,
|
||||
batch_size=16,
|
||||
shuffle=True,
|
||||
num_workers=4,
|
||||
collate_fn=utils.collate_fn,
|
||||
|
||||
Reference in New Issue
Block a user