Project author: lRomul

Project description:
Lightweight library for training neural networks in PyTorch
Primary language: Python
Repository: git://github.com/lRomul/argus.git
Created: 2018-07-05T09:43:08Z
Project community: https://github.com/lRomul/argus

License: MIT License






Argus is a lightweight library for training neural networks in PyTorch.

Documentation

https://pytorch-argus.readthedocs.io

Installation

Requirements:

  • torch>=2.0.0

From pip:

  pip install pytorch-argus

From source:

  pip install -U git+https://github.com/lRomul/argus.git@dev

Example

A simple image classification example using create_model from pytorch-image-models (timm):

  from torch.utils.data import DataLoader
  from torchvision.datasets import MNIST
  from torchvision.transforms import Compose, ToTensor, Normalize

  import timm

  import argus
  from argus.callbacks import MonitorCheckpoint, EarlyStopping, ReduceLROnPlateau


  def get_data_loaders(batch_size):
      # Normalize with the MNIST mean (0.1307) and standard deviation (0.3081)
      data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
      train_mnist_dataset = MNIST(download=True, root="mnist_data",
                                  transform=data_transform, train=True)
      val_mnist_dataset = MNIST(download=False, root="mnist_data",
                                transform=data_transform, train=False)
      train_loader = DataLoader(train_mnist_dataset,
                                batch_size=batch_size, shuffle=True)
      val_loader = DataLoader(val_mnist_dataset,
                              batch_size=batch_size * 2, shuffle=False)
      return train_loader, val_loader


  class TimmModel(argus.Model):
      # The network is built by timm.create_model from params['nn_module']
      nn_module = timm.create_model


  if __name__ == "__main__":
      train_loader, val_loader = get_data_loaders(batch_size=256)

      params = {
          'nn_module': {
              'model_name': 'tf_efficientnet_b0_ns',
              'pretrained': False,
              'num_classes': 10,
              'in_chans': 1,  # MNIST images are single-channel
              'drop_rate': 0.2,
              'drop_path_rate': 0.2
          },
          'optimizer': ('Adam', {'lr': 0.01}),
          'loss': 'CrossEntropyLoss',
          'device': 'cuda'
      }

      model = TimmModel(params)

      callbacks = [
          MonitorCheckpoint(dir_path='mnist', monitor='val_accuracy', max_saves=3),
          EarlyStopping(monitor='val_accuracy', patience=9),
          ReduceLROnPlateau(monitor='val_accuracy', factor=0.5, patience=3)
      ]

      model.fit(train_loader,
                val_loader=val_loader,
                num_epochs=50,
                metrics=['accuracy'],
                callbacks=callbacks,
                metrics_on_train=True)
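All three callbacks in the example watch val_accuracy. The sketch below is a simplified, pure-Python model of what patience-based monitoring means for this configuration; it is not argus code. PlateauMonitor and replay are hypothetical names, and the sketch assumes that an action fires once the number of consecutive epochs without a strict improvement reaches patience, that the learning-rate counter restarts after each reduction, and that max_saves keeps only the most recent checkpoints. The exact conventions of the argus callbacks (improvement thresholds, off-by-one details) may differ, so check the argus documentation before relying on specific epoch counts.

```python
from dataclasses import dataclass


@dataclass
class PlateauMonitor:
    """Hypothetical helper (simplified illustration, not argus code).

    Tracks epochs without improvement, assuming higher is better as for
    'val_accuracy', and signals once `patience` consecutive epochs pass
    without a strictly better value.
    """

    patience: int
    best: float = float("-inf")
    wait: int = 0

    def step(self, value: float) -> bool:
        """Record one epoch's value; return True when patience runs out."""
        if value > self.best:
            self.best = value
            self.wait = 0
            return False
        self.wait += 1
        return self.wait >= self.patience


def replay(val_accuracy, lr=0.01, factor=0.5, lr_patience=3,
           stop_patience=9, max_saves=3):
    """Replay per-epoch val_accuracy values through the simplified monitors.

    Returns (epochs_run, final_lr, kept_checkpoints), where kept_checkpoints
    lists the 1-based epochs whose checkpoints survive the max_saves limit
    (assumed here to mean "keep the most recent saves").
    """
    lr_monitor = PlateauMonitor(lr_patience)
    stop_monitor = PlateauMonitor(stop_patience)
    saved = []  # epochs at which the monitored value improved
    epochs_run = 0
    for epoch, value in enumerate(val_accuracy, start=1):
        epochs_run = epoch
        if value > stop_monitor.best:
            saved.append(epoch)   # a checkpoint would be written here
        if lr_monitor.step(value):
            lr *= factor          # halve the learning rate (factor=0.5)
            lr_monitor.wait = 0   # start counting a new plateau
        if stop_monitor.step(value):
            break                 # early stopping
    return epochs_run, lr, saved[-max_saves:]


history = [0.90, 0.93, 0.95, 0.97] + [0.96] * 12  # made-up accuracy values
print(replay(history))  # → (13, 0.00125, [2, 3, 4])
```

In this made-up history, accuracy stops improving after epoch 4, so the learning rate is halved at epochs 7, 10 and 13, training stops at epoch 13 after nine epochs without improvement, and only the checkpoints from epochs 2, 3 and 4 are kept.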

More examples can be found in the examples directory of the repository.
Additional guides on how to customize and use argus components can be found in the Guides section of the documentation.

Why this name, Argus?

The library name is a reference to a planet from World of Warcraft.
Argus is the original homeworld of the eredar (a race of supremely talented magic-wielders), now located within the Twisting Nether.
It was once described as a utopian world whose inhabitants were both vastly intelligent and highly gifted in magic.
It has since been twisted by demonic, chaotic energies and has become the stronghold and homeworld of the Burning Legion.