Shape structure dataset
import numpy as np
from matplotlib import pyplot as plt
from curvey.shape_structure_dataset import ShapeStructureDataset
The 2D Shape Structure Dataset makes for a useful source of curves for testing shape interpolation.
Download the shapes zip file here. It does not need to be unzipped.
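No extraction step is needed because a zip archive can be read member by member in place. The sketch below shows one way to do that with Python's standard `zipfile` module. It only illustrates the general idea and is not curvey's implementation. The helper names (`list_json_members`, `read_json_member`), the member path and the JSON layout are hypothetical stand-ins, not the real contents of `ShapesJSON.zip`.

```python
import io
import json
import zipfile


def list_json_members(zip_source):
    """Return the sorted names of all .json members in a zip archive, without extracting it."""
    with zipfile.ZipFile(zip_source) as zf:
        return sorted(name for name in zf.namelist() if name.endswith(".json"))


def read_json_member(zip_source, member):
    """Parse a single JSON member directly from the archive."""
    with zipfile.ZipFile(zip_source) as zf:
        with zf.open(member) as f:
            return json.load(f)  # json.load accepts the binary stream returned by zf.open


# Build a tiny stand-in archive in memory (hypothetical member name and JSON layout).
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.writestr("ShapesJSON/apple-1.json", json.dumps({"points": [[0, 0], [1, 0], [0, 1]]}))
    zf.writestr("README.txt", "not a shape")

print(list_json_members(buf))
print(read_json_member(buf, "ShapesJSON/apple-1.json"))
```

Passing a path such as `"ShapesJSON.zip"` instead of the in-memory buffer works the same way, because `zipfile.ZipFile` accepts either a filename or a file-like object.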
dataset = ShapeStructureDataset("~/Downloads/ShapesJSON.zip")
All classes in the dataset:
print(", ".join(dataset.class_names))
apple, bat, beetle, bell, bird, bone, bottle, brick, butterfly, camel, car, carriage, cattle, cellular_phone, chicken, children, chopper, classic, comma, crown, cup, deer, device, dino, dog, elephant, face, fish, flatfish, fly, fork, fountain, frog, glas, guitar, hammer, hat, hcircle, heart, horse, horseshoe, image, jar, key, lizzard, lmfish, misk, octopus, pencil, personal_car, pocket, rat, ray, sea_snake, shoe, spoon, spring, stef, teddy, tree, turtle, watch
Most classes contain multiple exemplars:
print(", ".join(dataset.names_by_class["apple"]))
apple-1, apple-10, apple-11, apple-12, apple-13, apple-14, apple-15, apple-16, apple-17, apple-18, apple-19, apple-2, apple-20, apple-3, apple-4, apple-5, apple-6, apple-7, apple-8, apple-9
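The exemplar names above are listed in lexicographic order, so `apple-10` comes before `apple-2`. If you want them in numeric order, a small sort key is enough. This is a plain-Python sketch, not part of curvey. It assumes names follow the `<class>-<number>` pattern shown above, and `exemplar_sort_key` is a hypothetical helper.

```python
import re


def exemplar_sort_key(name: str):
    """Sort key that compares the trailing '-<number>' numerically, e.g. 'apple-2' < 'apple-10'."""
    match = re.fullmatch(r"(.*)-(\d+)", name)
    if match is None:
        return (name, -1)  # names without a numeric suffix are compared by name alone
    return (match.group(1), int(match.group(2)))


names = ["apple-1", "apple-10", "apple-11", "apple-2", "apple-3"]
print(sorted(names, key=exemplar_sort_key))
# ['apple-1', 'apple-2', 'apple-3', 'apple-10', 'apple-11']
```

With the dataset loaded, the same key would apply as `sorted(dataset.names_by_class["apple"], key=exemplar_sort_key)`. Class names use underscores (for example `sea_snake`), so the regex splits only on the final hyphen.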
Load a curve either by its full name or by its class and index:
dataset.load_curve("apple", 6)
Out[5]:
Curve(n=103)
Plot the first exemplar in each class:
n = len(dataset.class_names)
n_cols = 10
n_rows = int(np.ceil(n / n_cols))
fig, ax = plt.subplots(figsize=(12, 1.2 * n_rows))
for i, class_name in enumerate(dataset.class_names):
    # Center the first exemplar and scale it so its farthest point is 0.35 from the origin
    c = dataset.load_curve(class_name, 0).translate("center")
    max_r = np.linalg.norm(c.points, axis=1).max()
    c = c.scale(0.35 / max_r)
    # Grid cell: row y, column x; rows are drawn downward, hence -y
    y, x = divmod(i, n_cols)
    c.translate([x, -y]).plot(color="black")
    plt.text(x, 0.4 - y, class_name, horizontalalignment="center")
ax.axis("equal")
ax.axis("off")
ax.set_ylim((-(n_rows - 0.5), 0.75));
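The plotting loop does two things per class. It normalizes each curve into a cell of radius 0.35, leaving a margin inside each unit-spaced grid cell, and it places class `i` at column `i % n_cols` and row `i // n_cols`. The sketch below reproduces that arithmetic on plain NumPy arrays so it can be checked without the dataset. It assumes the center is the vertex mean; curvey's `translate("center")` may use a different notion of center, such as an area centroid. `normalize_to_radius` and `grid_offsets` are hypothetical helpers.

```python
import numpy as np


def normalize_to_radius(points: np.ndarray, radius: float = 0.35) -> np.ndarray:
    """Center a polygon on its vertex mean (an assumption) and scale so the farthest vertex sits at `radius`."""
    centered = points - points.mean(axis=0)
    max_r = np.linalg.norm(centered, axis=1).max()
    return centered * (radius / max_r)


def grid_offsets(n: int, n_cols: int) -> np.ndarray:
    """(x, y) offsets for n items laid out row by row, with successive rows placed lower."""
    rows, cols = np.divmod(np.arange(n), n_cols)
    return np.column_stack([cols, -rows]).astype(float)


square = np.array([[0.0, 0.0], [2.0, 0.0], [2.0, 2.0], [0.0, 2.0]])
unit = normalize_to_radius(square)
print(np.linalg.norm(unit, axis=1))  # every corner ends up 0.35 from the origin
print(grid_offsets(12, 10)[[0, 9, 10, 11]])  # first, last of row 0, then the start of row 1
```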