import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW

# Select the compute device: prefer CUDA when a GPU is visible to torch,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the tokenizer and a sequence-classification model from the same
# checkpoint. NUM_CLASSES is a placeholder constant expected to be defined
# elsewhere. The model is moved to `device` in the same expression, because
# Module.to() returns the module itself.
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=NUM_CLASSES,
).to(device)

# Build the training data pipeline. `YourDataset` and BATCH_SIZE are
# placeholders to fill in. The training loop indexes each batch with
# "text" and "label", so the dataset presumably yields mappings with those
# keys — confirm against your dataset.
train_dataset = YourDataset(...)  # Replace with your own dataset
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,  # reshuffle sample order every epoch
)

# Define optimizer and learning-rate scheduler.
# torch.optim.AdamW replaces transformers.AdamW, which is deprecated and has
# been dropped from newer transformers releases. transformers.AdamW used
# eps=1e-6 and weight_decay=0.0 by default, whereas torch.optim.AdamW defaults
# to eps=1e-8 and weight_decay=0.01. Both values are passed explicitly so the
# optimization behaves the same as before.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    eps=1e-6,
    weight_decay=0.0,
)
# StepLR multiplies the learning rate by LR_GAMMA every LR_STEP_SIZE calls to
# scheduler.step(). The training loop calls it once per epoch.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=LR_STEP_SIZE, gamma=LR_GAMMA
)

# Training loop: one optimizer update per mini-batch, one LR-scheduler step
# per epoch.
model.train()
for epoch in range(NUM_EPOCHS):
    for batch in train_dataloader:
        # Tokenize the raw text on the fly, padding to the longest sequence
        # in this batch and truncating to the model's maximum length.
        encoded = tokenizer(
            batch["text"],
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
        targets = batch["label"].to(device)

        # Forward pass and cross-entropy on the raw (unnormalized) logits.
        logits = model(**model_inputs).logits
        loss = torch.nn.functional.cross_entropy(logits, targets)

        # Clear stale gradients, backpropagate, then update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # The scheduler counts epochs, so it advances once per outer iteration.
    scheduler.step()

# Persist the trained model and its tokenizer with save_pretrained.
# Both directories are placeholders; point them at real locations.
for _artifact, _save_dir in (
    (model, "path/to/save/model"),
    (tokenizer, "path/to/save/tokenizer"),
):
    _artifact.save_pretrained(_save_dir)