Commit 68d71c7e authored by Sebastian Höffner

Initial commit.

artifacts
!artifacts/.keep
# Created by https://www.gitignore.io/api/python
# Edit at https://www.gitignore.io/?templates=python
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
### Python Patch ###
.venv/
# End of https://www.gitignore.io/api/python
LICENSE 0 → 100644
MIT License
Copyright (c) 2019 Sebastian Höffner
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Pipfile 0 → 100644
[[source]]
name = "pypi"
url = "https://pypi.org/simple"
verify_ssl = true
[dev-packages]
[packages]
tensorflow = "*"
numpy = "*"
pandas = "*"
keras = "*"
gitpython = "*"
pymongo = "*"
sacred = "*"
matplotlib = "*"
h5py = "*"
# Learning and generating table settings
## Running training
This project uses [sacred](https://github.com/IDSIA/sacred), so training runs are started and configured with the usual sacred commands.
```bash
python tablesetting.py print_config # To see what can be configured
python tablesetting.py # Just train
```
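The configuration values listed by `print_config` can be overridden per run. As a rough sketch (not part of this repository), the experiment can also be driven from Python via sacred's `Experiment.run` with `config_updates`; the override values below are arbitrary examples:

```python
from tablesetting import experiment

# Override selected config entries for a single run; this mirrors the CLI call
# `python tablesetting.py with batch_size=50 epochs=10000`.
run = experiment.run(config_updates={'batch_size': 50, 'epochs': 10000})
print(run.result)  # train() returns the final (generator loss, discriminator loss)
```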
### Persisting run results using Sacred, Mongo, and Docker
If you have Docker and docker-compose, you can spin up MongoDB and Omniboard instances using the `docker-compose.yml` file.
You can then train while storing the results in MongoDB, and inspect them in your local Omniboard instance at [http://localhost:9000](http://localhost:9000).
```bash
docker-compose up
# in new terminal
python tablesetting.py -m sacred
```
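The `-m sacred` flag attaches sacred's MongoDB observer from the command line. Attaching the observer in code should work as well; a minimal sketch, assuming the docker-compose services above are running (with newer sacred versions, `MongoObserver(url=..., db_name=...)` replaces `MongoObserver.create`):

```python
from sacred.observers import MongoObserver

from tablesetting import experiment

# Point the observer at the MongoDB container from docker-compose;
# Omniboard reads from the same `sacred` database.
experiment.observers.append(
    MongoObserver.create(url='mongodb://localhost:27017', db_name='sacred'))

experiment.run()  # runs the default `train` command, storing results in MongoDB
```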
## Acknowledgments
The research reported in this repository has been supported by the German Research Foundation DFG, as part of Collaborative Research Center (Sonderforschungsbereich) 1320 "EASE – Everyday Activity Science and Engineering", University of Bremen (https://www.ease-crc.org/). The research was conducted in subproject P01 "Embodied semantics for the language of action and change: Combining analysis, reasoning and simulation".
# `settings.npy`
Contains complete individual (one-person) settings, each consisting of a cup, a fork, and a knife.
The settings are rotated so that they are roughly aligned from the perspective of a person eating at that seat, with the plate at `(0, 0)`.
Data format:
- `cup_x`, `cup_y`, `fork_x`, `fork_y`, `knife_x`, `knife_y`
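A minimal sketch of reading this file, assuming it is a plain NumPy array with the six columns in the order listed above:

```python
import numpy as np

settings = np.load('data/settings.npy')  # expected shape: (n_settings, 6)

# Column order as documented: cup_x, cup_y, fork_x, fork_y, knife_x, knife_y
cup_xy = settings[:, 0:2]
fork_xy = settings[:, 2:4]
knife_xy = settings[:, 4:6]
```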
# `regular_items.npy`
Contains the items from `settings.npy` (i.e. with the same "normalization" applied), but each row is a single item instead of one row per setting.
Thus, this file contains three times as many entries as `settings.npy`.
Data format:
- `x`, `y`, `label`
With labels:
- 0: cup
- 1: fork
- 2: knife
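Analogously, a sketch of reading the per-item file and splitting it by label, assuming a plain `(n_items, 3)` array that sits next to `settings.npy` in `data/`:

```python
import numpy as np

items = np.load('data/regular_items.npy')  # expected shape: (n_items, 3)

# Column 2 holds the label: 0 = cup, 1 = fork, 2 = knife
cups = items[items[:, 2] == 0, :2]
forks = items[items[:, 2] == 1, :2]
knives = items[items[:, 2] == 2, :2]

# There are three items per setting, so there should be three times as many rows.
assert len(items) == 3 * len(np.load('data/settings.npy'))
```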
docker-compose.yml 0 → 100644
version: '3.1'

services:
  omniboard:
    image: vivekratnavel/omniboard
    command: ['--mu', 'mongodb://mongo:27017/sacred']
    ports:
      - 9000:9000
    depends_on:
      - mongo

  mongo:
    image: mongo
    restart: always
    volumes:
      - data:/data/db
    ports:
      - 27017:27017
    environment:
      MONGO_INITDB_DATABASE: sacred

volumes:
  data:
ingredients/dataset.py 0 → 100644
import numpy as np

from sacred import Ingredient

dataset_ingredient = Ingredient('dataset')


@dataset_ingredient.config
def config():
    # The file to load
    filename = 'data/settings.npy'
    filename  # reference the variable so flake8 does not flag W0612


@dataset_ingredient.capture
def load_dataset(filename):
    """Loads the configured dataset file."""
    return np.load(filename)
tablesetting.py 0 → 100644
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

from sacred import Experiment
from keras.layers import Input, Dense
from keras.models import Model
from keras.optimizers import Adam

from ingredients.dataset import dataset_ingredient, load_dataset

experiment = Experiment(name='GAN', ingredients=[dataset_ingredient, ])
class GAN:
    def __init__(self):
        """Initializes and compiles the GAN model."""
        self.latent_size = 100
        self.compile()

    @property
    def generator(self):
        # Maps a latent vector to six values: (x, y) for cup, fork, and knife.
        if not hasattr(self, '_generator'):
            inputs = Input(shape=(self.latent_size, ))
            hidden = Dense(128, activation='relu')(inputs)
            hidden = Dense(128, activation='relu')(hidden)
            outputs = Dense(6, activation='tanh')(hidden)
            self._generator = Model(inputs=inputs, outputs=outputs, name='Generator')
        return self._generator

    @property
    def discriminator(self):
        # Classifies a six-value setting as real (1) or generated (0).
        if not hasattr(self, '_discriminator'):
            inputs = Input(shape=(6, ))
            hidden = Dense(128, activation='relu')(inputs)
            hidden = Dense(128, activation='relu')(hidden)
            outputs = Dense(1, activation='sigmoid')(hidden)
            self._discriminator = Model(inputs=inputs, outputs=outputs, name='Discriminator')
        return self._discriminator
    def compile(self):
        optimizer = Adam(0.00002, 0.5)

        # First, compile discriminator and set it to non-trainable
        self.discriminator.compile(optimizer=optimizer,
                                   loss='binary_crossentropy')
        self.discriminator.trainable = False

        # Then, create gan with non-trainable discriminator and compile it
        z = Input(shape=(self.latent_size, ))
        outputs = self.discriminator(self.generator(z))
        self.gan = Model(inputs=z, outputs=outputs, name=self.__class__.__name__)
        self.gan.compile(optimizer=optimizer,
                         loss='binary_crossentropy')
    def step(self, data, batch_size):
        """Performs one training step, i.e. one training batch.

        Should be replaced with epoch learning."""
        # Train the discriminator on a mixed batch: generated samples labeled 0,
        # real samples labeled 1.
        batch = data[np.random.randint(0, len(data), batch_size)]
        z = np.random.randn(batch_size, self.latent_size)
        generated = self.generator.predict(z)
        ld = self.discriminator.train_on_batch(
            np.vstack((generated, batch)),
            np.hstack((np.zeros((batch_size,)), np.ones((batch_size,)))))

        # Train the generator through the combined model on fresh noise.
        noise = np.random.randn(batch_size, self.latent_size)
        generated = self.generator.predict(noise)
        lg = self.gan.train_on_batch(noise, 1 - self.discriminator.predict(generated))
        return ld, lg
@experiment.config
def config():
    # The batch size
    batch_size = 25
    # The number of training steps (not real epochs at the moment)
    epochs = 50000
    # Artifact directory
    artifacts_path = 'artifacts'

    # For flake8, ignore W0612 "assigned but never used" by using variables
    batch_size
    epochs
    artifacts_path


@experiment.capture
def artifact_path(filename, artifacts_path):
    """Appends the file name to the artifacts_path."""
    return Path(artifacts_path) / filename
def save_model(filename, model):
    """Saves the model."""
    path = artifact_path(filename)
    path.write_text(model.to_json())
    experiment.add_artifact(path, name=model.name)


def save_weights(filename, model):
    """Saves model weights."""
    path = artifact_path(filename)
    model.save_weights(path)
    experiment.add_artifact(path)
def plot_samples(filename, samples, title=None):
    """Plots item positions as a scatter plot and stores it as an artifact."""
    samples = samples.reshape(-1, 2)
    C = ['red', 'blue', 'green', 'orange']
    colors = C[:3] * (samples.shape[0] // 3)

    fig = plt.figure()
    ax = fig.add_subplot(111)
    if title:
        ax.set_title(title)
    ax.scatter(*zip(*samples), c=colors, alpha=0.3)
    ax.scatter(0, 0, c=C[-1])
    ax.legend([mpatches.Rectangle((0, 0), 1, 1, fc=c) for c in C],
              ['cup', 'fork', 'knife', 'plate'])

    # save
    path = artifact_path(filename)
    fig.savefig(path)
    experiment.add_artifact(path)
    return fig
@experiment.automain
@experiment.command
def train(_run, _log, dataset, batch_size, epochs):
    data = load_dataset(dataset['filename'])
    experiment.add_resource(dataset['filename'])
    plot_samples('dataset.png', data, 'Original')

    gan = GAN()
    _log.info('GAN:')
    gan.gan.summary(print_fn=_log.info)
    _log.info('Generator:')
    gan.generator.summary(print_fn=_log.info)
    _log.info('Discriminator:')
    gan.discriminator.summary(print_fn=_log.info)
    save_model('gan.json', gan.gan)

    test_z = np.random.randn(100, gan.latent_size)
    for epoch in range(1, epochs + 1):
        ld, lg = gan.step(data, batch_size)  # TODO: replace with real epoch
        _run.log_scalar('loss.generator', lg, epoch)
        _run.log_scalar('loss.discriminator', ld, epoch)
        if not epoch % 1000:
            _log.info(f'Epoch {epoch} - Losses: G {lg:.4f}, D {ld:.4f}')
        if not epoch % 10000:
            plot_samples(f'generated_{epoch}.png', gan.generator.predict(test_z),
                         f'Generated {epoch}')

    save_weights('gan.h5', gan.gan)
    plot_samples('generated.png', gan.generator.predict(test_z), f'Generated {epoch}')
    return lg, ld