Sunday, December 19, 2021

Stock Charts Detection Using Image Classification Model ResNet

Intro

This tutorial explores image classification in PyTorch using state-of-the-art computer vision models. The dataset used in this tutorial has three heavily imbalanced classes, so we will explore augmentation as a solution to the imbalance problem.
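Using augmentation to balance the classes starts with knowing how many extra images each minority class needs. Here is a minimal sketch that assumes the labels are available as a plain Python list of integer class ids. The 50/10/5 split below is made-up data, not the real dataset's counts, and the helper name `augmentation_plan` is hypothetical:

```python
from collections import Counter


def augmentation_plan(labels):
    """Return, for each class, how many extra (augmented) images would be
    needed to bring it up to the size of the largest class."""
    counts = Counter(labels)
    target = max(counts.values())
    return {cls: target - n for cls, n in counts.items()}


# Hypothetical, made-up labels for three imbalanced classes.
labels = [0] * 50 + [1] * 10 + [2] * 5
print(augmentation_plan(labels))  # → {0: 0, 1: 40, 2: 45}
```

To generate each missing image, you would apply random transforms to an existing image of that class, for example torchvision's `transforms.ColorJitter` or `transforms.RandomAffine`. Which transforms keep a stock chart meaningful is a judgment call.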

Data used in this notebook can be found at https://www.nbshare.io/blog/datasets/

Contents:
  1. Data loading
    • Loading labels
    • Train-test splitting
    • Augmentation
    • Creating Datasets
    • Random Weighted Sampling and DataLoaders
  2. CNN building and fine-tuning ResNet
    • CNN
    • ResNet
  3. Setup and training
  4. Evaluation
  5. Testing
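For the random weighted sampling step in the contents list, `torch.utils.data.WeightedRandomSampler` takes one weight per training sample. The sketch below shows the common inverse-class-frequency weighting. It assumes integer class labels; the label counts and the helper name `sample_weights` are hypothetical:

```python
from collections import Counter


def sample_weights(labels):
    """Inverse-frequency weight per sample: every class receives the same
    total weight, so a weighted sampler draws the classes roughly equally."""
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]


# Hypothetical, made-up imbalanced labels.
labels = [0] * 50 + [1] * 10 + [2] * 5
weights = sample_weights(labels)
print(weights[:2], weights[-1])  # → [0.02, 0.02] 0.2
```

You would then pass these weights as `WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)` to `DataLoader(train_dataset, batch_size=32, sampler=sampler)`, where `train_dataset` and the batch size are placeholders. Leave `shuffle` unset, because `DataLoader` does not accept a sampler together with `shuffle=True`.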
Data Loading
In [1]:
import os
import random
import numpy as np
import pandas as pd
from PIL import Image
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import datasets, models
from torchvision import transforms
import matplotlib.pyplot as plt

Setting the device so that the GPU is used when one is available.

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
Out[2]:
device(type='cuda')

Identifying the data paths.

In [4]:
data_dir = "images/"
labels_file = "images_labeled.csv"
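Before you build a `Dataset`, it is worth checking that every image named in the labels file actually exists under `data_dir`. The following minimal sketch uses only `os.path`; the helper name `image_paths` is hypothetical:

```python
import os


def image_paths(data_dir, filenames):
    """Join each filename to data_dir and split the results into
    (found, missing) lists depending on whether the file exists."""
    found, missing = [], []
    for name in filenames:
        path = os.path.join(data_dir, name)
        (found if os.path.isfile(path) else missing).append(path)
    return found, missing
```

Call it as `found, missing = image_paths(data_dir, filenames)`, with `filenames` taken from the labels CSV. A non-empty `missing` list points to a mismatch between the CSV and the `images/` folder.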
Loading Labels

Since the labels are in a CSV file, we use

(continued...)
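Here is a minimal sketch of reading such a labels CSV with the standard-library `csv` module. In a notebook that already imports pandas, `pd.read_csv(labels_file)` is the more common choice. The sketch assumes a header row with hypothetical columns `image` and `label`; the real file's column names may differ:

```python
import csv
import io


def load_labels(csv_file):
    """Read (filename, label) pairs from an open CSV file whose header has
    the hypothetical columns 'image' and 'label'. Labels stay strings."""
    reader = csv.DictReader(csv_file)
    return [(row["image"], row["label"]) for row in reader]


# In-memory stand-in for images_labeled.csv (made-up rows).
sample = io.StringIO("image,label\nchart1.png,0\nchart2.png,2\n")
print(load_labels(sample))  # → [('chart1.png', '0'), ('chart2.png', '2')]
```

For the real file, use `with open(labels_file, newline="") as f: rows = load_labels(f)`. Convert the labels with `int(...)` if the classes are stored as integers.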

from Planet SciPy