1
0
Fork 0
mirror of https://github.com/shouptech/flask-tutorial.git synced 2026-02-03 07:29:42 +00:00

Update unit tests for sqlalchemy

This commit is contained in:
Emma 2018-11-03 23:45:12 -06:00
parent c3a3eb30e2
commit 4622b87a6f
6 changed files with 38 additions and 39 deletions

View file

@ -7,11 +7,15 @@ from flask_sqlalchemy import SQLAlchemy
db = SQLAlchemy() db = SQLAlchemy()
def init_db():
db.create_all()
@click.command('init-db') @click.command('init-db')
@with_appcontext @with_appcontext
def init_db_command(): def init_db_command():
"""Clear the existing data and create new tables.""" """Clear the existing data and create new tables."""
db.create_all() init_db()
click.echo('Initialized the database.') click.echo('Initialized the database.')
def init_app(app): def init_app(app):

View file

@ -1,9 +1,12 @@
import os import os
import tempfile import tempfile
from datetime import datetime
import pytest import pytest
from werkzeug.security import generate_password_hash
from flaskr import create_app from flaskr import create_app
from flaskr.db import get_db, init_db from flaskr.db import init_db, db, User, Post
with open(os.path.join(os.path.dirname(__file__), 'data.sql'), 'rb') as f: with open(os.path.join(os.path.dirname(__file__), 'data.sql'), 'rb') as f:
_data_sql = f.read().decode('utf8') _data_sql = f.read().decode('utf8')
@ -15,12 +18,19 @@ def app():
app = create_app({ app = create_app({
'TESTING': True, 'TESTING': True,
'DATABASE': db_path, 'SECRET_KEY': 'testing',
'SQLALCHEMY_DATABASE_URI': 'sqlite:///{}'.format(db_path)
}) })
with app.app_context(): with app.app_context():
init_db() init_db()
get_db().executescript(_data_sql) db.session.add(User(username='test',
password=generate_password_hash('test')))
db.session.add(Post(title='test title',
body='test\nbody',
user_id=1,
created=datetime(year=2018,month=1,day=1)))
db.session.commit()
yield app yield app

View file

@ -1,8 +0,0 @@
INSERT INTO user (username, password)
VALUES
('test', 'pbkdf2:sha256:50000$TCI4GzcX$0de171a4f4dac32e3364c7ddc7c14f3e2fa61f2d17574483f7ffbb431b4acb2f'),
('other', 'pbkdf2:sha256:50000$kJPKsz6N$d2d4784f1b030a9761f5ccaeeaca413f27f2ecb76d6168407af962ddce849f79');
INSERT INTO post (title, body, author_id, created)
VALUES
('test title', 'test' || x'0a' || 'body', 1, '2018-01-01 00:00:00');

View file

@ -1,6 +1,6 @@
import pytest import pytest
from flask import g, session from flask import g, session
from flaskr.db import get_db from flaskr.db import User
def test_register(client, app): def test_register(client, app):
@ -11,9 +11,7 @@ def test_register(client, app):
assert 'http://localhost/auth/login' == response.headers['Location'] assert 'http://localhost/auth/login' == response.headers['Location']
with app.app_context(): with app.app_context():
assert get_db().execute( assert User.query.all() is not None
"select * from user where username = 'a'",
).fetchone() is not None
@pytest.mark.parametrize(('username', 'password', 'message'), ( @pytest.mark.parametrize(('username', 'password', 'message'), (
@ -36,7 +34,7 @@ def test_login(client, auth):
with client: with client:
client.get('/') client.get('/')
assert session['user_id'] == 1 assert session['user_id'] == 1
assert g.user['username'] == 'test' assert g.user.username == 'test'
@pytest.mark.parametrize(('username', 'password', 'message'), ( @pytest.mark.parametrize(('username', 'password', 'message'), (

View file

@ -1,5 +1,5 @@
import pytest import pytest
from flaskr.db import get_db from flaskr.db import db, Post
def test_index(client, auth): def test_index(client, auth):
@ -28,9 +28,10 @@ def test_login_required(client, path):
def test_author_required(app, client, auth): def test_author_required(app, client, auth):
# change the post author to another user # change the post author to another user
with app.app_context(): with app.app_context():
db = get_db() post = Post.query.filter_by(user_id=1).first()
db.execute('UPDATE post SET author_id = 2 WHERE id = 1') post.user_id = 2
db.commit() db.session.add(post)
db.session.commit()
auth.login() auth.login()
# current user can't modify other user's post # current user can't modify other user's post
@ -54,8 +55,7 @@ def test_create(client, auth, app):
client.post('/create', data={'title': 'created', 'body': ''}) client.post('/create', data={'title': 'created', 'body': ''})
with app.app_context(): with app.app_context():
db = get_db() count = Post.query.count()
count = db.execute('SELECT COUNT(id) FROM post').fetchone()[0]
assert count == 2 assert count == 2
@ -65,9 +65,8 @@ def test_update(client, auth, app):
client.post('/1/update', data={'title': 'updated', 'body': ''}) client.post('/1/update', data={'title': 'updated', 'body': ''})
with app.app_context(): with app.app_context():
db = get_db() post = Post.query.filter_by(id=1).first()
post = db.execute('SELECT * FROM post WHERE id = 1').fetchone() assert post.title == 'updated'
assert post['title'] == 'updated'
@pytest.mark.parametrize('path', ( @pytest.mark.parametrize('path', (
@ -85,6 +84,5 @@ def test_delete(client, auth, app):
assert response.headers['Location'] == 'http://localhost/' assert response.headers['Location'] == 'http://localhost/'
with app.app_context(): with app.app_context():
db = get_db() post = Post.query.filter_by(id=1).first()
post = db.execute('SELECT * FROM post WHERE id = 1').fetchone()
assert post is None assert post is None

View file

@ -1,18 +1,8 @@
import sqlite3 import sqlite3
import pytest import pytest
from flaskr.db import get_db
from flaskr.db import Post, User
def test_get_close_db(app):
with app.app_context():
db = get_db()
assert db is get_db()
with pytest.raises(sqlite3.ProgrammingError) as e:
db.execute('SELECT 1')
assert 'closed' in str(e)
def test_init_db_command(runner, monkeypatch): def test_init_db_command(runner, monkeypatch):
class Recorder(object): class Recorder(object):
@ -25,3 +15,10 @@ def test_init_db_command(runner, monkeypatch):
result = runner.invoke(args=['init-db']) result = runner.invoke(args=['init-db'])
assert 'Initialized' in result.output assert 'Initialized' in result.output
assert Recorder.called assert Recorder.called
def test_repr():
user = User(username='test')
assert 'test' in repr(user)
post = Post(title='test')
assert 'test' in repr(post)