git仓库:https://github.com/FoundationVision/LlamaGen
数据集准备
如果用ImageFolder读取,则最好和ImageNet一致。
data_path/
class_1/
image_001.jpg
image_002.jpg
...
class_2/
image_003.jpg
image_004.jpg
...
...
class_n/
image_005.jpg
image_006.jpg
...
则
python">def build_imagenet(args, transform):
return ImageFolder(args.data_path, transform=transform)
如果是train,val,test,最好整理成
data_path/
train/
class_1/
image_001.jpg
image_002.jpg
...
class_2/
image_003.jpg
image_004.jpg
...
...
val/
class_1/
image_005.jpg
image_006.jpg
...
class_2/
image_007.jpg
image_008.jpg
...
...
test/
class_1/
image_009.jpg
image_010.jpg
...
class_2/
image_011.jpg
image_012.jpg
...
...
读取:
python">train_dataset = datasets.ImageFolder(root=args.data_path + '/train', transform=transform)
# 加载验证集
val_dataset = datasets.ImageFolder(root=args.data_path + '/val', transform=transform)
# 加载测试集
test_dataset = datasets.ImageFolder(root=args.data_path + '/test', transform=transform)
数据集预处理
NCCL_IB_DISABLE=1 NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=3 torchrun \
--nnodes=1 --nproc_per_node=1 --node_rank=0 \
--master_addr=localhost \
autoregressive/train/extract_codes_c2i.py \
--vq-ckpt ./pretrained_models/vq_ds16_c2i.pt \
--data-path 你的数据集 \
--code-path VQGAN处理的数据集放在哪 \
--ten-crop \
--crop-range 1.1 \
--image-size 256
这里改成自己数据集的长度
ten-crop是作者定义的一种数据增强,每一个图片生成10个crop。最好修改一下这里的代码,训练的时候仅仅取一个。
注释掉这个self.flip
训练
NCCL_IB_DISABLE=1 NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=4,5 torchrun \
--nnodes=1 --nproc_per_node=2 --node_rank=0 \
--master_addr=localhost \
--master_port=8902 \
./autoregressive/train/train_c2i.py \
--cloud-save-path xxx \
--code-path 之前放VQGAN处理后数据集的地方 \
--image-size 256 \
--gpt-model GPT-B
生成
修改类别,权重
parser.add_argument("--num-classes", type=int, default=xxx)
label定义:
我的生成结果(数据集用了TinyImageNet的8个类)
300step
1500step