yacwc
This commit is contained in:
3
data.py
3
data.py
@@ -15,7 +15,8 @@ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cp
|
|||||||
if sys.platform == "win32":
|
if sys.platform == "win32":
|
||||||
PATH_ROOT = r"D:\ishan\ml\inaturalist\\"
|
PATH_ROOT = r"D:\ishan\ml\inaturalist\\"
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Not defined for this platform")
|
PATH_ROOT = '/home/thebears/data/ml/inaturalist'
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_transform(train):
|
def get_transform(train):
|
||||||
|
|||||||
4
train.py
4
train.py
@@ -45,7 +45,7 @@ def run():
|
|||||||
|
|
||||||
train_data_loader = torch.utils.data.DataLoader(
|
train_data_loader = torch.utils.data.DataLoader(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
batch_size=8,
|
batch_size=16,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
collate_fn=utils.collate_fn,
|
collate_fn=utils.collate_fn,
|
||||||
@@ -53,7 +53,7 @@ def run():
|
|||||||
|
|
||||||
val_data_loader = torch.utils.data.DataLoader(
|
val_data_loader = torch.utils.data.DataLoader(
|
||||||
val_dataset,
|
val_dataset,
|
||||||
batch_size=8,
|
batch_size=16,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
collate_fn=utils.collate_fn,
|
collate_fn=utils.collate_fn,
|
||||||
|
|||||||
Reference in New Issue
Block a user