diff --git a/flaskr/db.py b/flaskr/db.py index fb85bfd..457ff10 100644 --- a/flaskr/db.py +++ b/flaskr/db.py @@ -7,11 +7,15 @@ from flask_sqlalchemy import SQLAlchemy db = SQLAlchemy() + +def init_db(): + db.create_all() + @click.command('init-db') @with_appcontext def init_db_command(): """Clear the existing data and create new tables.""" - db.create_all() + init_db() click.echo('Initialized the database.') def init_app(app): diff --git a/tests/conftest.py b/tests/conftest.py index cb2a7ec..eeb2419 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,12 @@ import os import tempfile +from datetime import datetime import pytest +from werkzeug.security import generate_password_hash + from flaskr import create_app -from flaskr.db import get_db, init_db +from flaskr.db import init_db, db, User, Post with open(os.path.join(os.path.dirname(__file__), 'data.sql'), 'rb') as f: _data_sql = f.read().decode('utf8') @@ -15,12 +18,19 @@ def app(): app = create_app({ 'TESTING': True, - 'DATABASE': db_path, + 'SECRET_KEY': 'testing', + 'SQLALCHEMY_DATABASE_URI': 'sqlite:///{}'.format(db_path) }) with app.app_context(): init_db() - get_db().executescript(_data_sql) + db.session.add(User(username='test', + password=generate_password_hash('test'))) + db.session.add(Post(title='test title', + body='test\nbody', + user_id=1, + created=datetime(year=2018,month=1,day=1))) + db.session.commit() yield app diff --git a/tests/data.sql b/tests/data.sql deleted file mode 100644 index 9b68006..0000000 --- a/tests/data.sql +++ /dev/null @@ -1,8 +0,0 @@ -INSERT INTO user (username, password) -VALUES - ('test', 'pbkdf2:sha256:50000$TCI4GzcX$0de171a4f4dac32e3364c7ddc7c14f3e2fa61f2d17574483f7ffbb431b4acb2f'), - ('other', 'pbkdf2:sha256:50000$kJPKsz6N$d2d4784f1b030a9761f5ccaeeaca413f27f2ecb76d6168407af962ddce849f79'); - -INSERT INTO post (title, body, author_id, created) -VALUES - ('test title', 'test' || x'0a' || 'body', 1, '2018-01-01 00:00:00'); diff --git a/tests/test_auth.py b/tests/test_auth.py index 8165553..32c7ae6 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,6 +1,6 @@ import pytest from flask import g, session -from flaskr.db import get_db +from flaskr.db import User def test_register(client, app): @@ -11,9 +11,7 @@ def test_register(client, app): assert 'http://localhost/auth/login' == response.headers['Location'] with app.app_context(): - assert get_db().execute( - "select * from user where username = 'a'", - ).fetchone() is not None + assert User.query.all() is not None @pytest.mark.parametrize(('username', 'password', 'message'), ( @@ -36,7 +34,7 @@ def test_login(client, auth): with client: client.get('/') assert session['user_id'] == 1 - assert g.user['username'] == 'test' + assert g.user.username == 'test' @pytest.mark.parametrize(('username', 'password', 'message'), ( diff --git a/tests/test_blog.py b/tests/test_blog.py index 8178557..473e05e 100644 --- a/tests/test_blog.py +++ b/tests/test_blog.py @@ -1,5 +1,5 @@ import pytest -from flaskr.db import get_db +from flaskr.db import db, Post def test_index(client, auth): @@ -28,9 +28,10 @@ def test_login_required(client, path): def test_author_required(app, client, auth): # change the post author to another user with app.app_context(): - db = get_db() - db.execute('UPDATE post SET author_id = 2 WHERE id = 1') - db.commit() + post = Post.query.filter_by(user_id=1).first() + post.user_id = 2 + db.session.add(post) + db.session.commit() auth.login() # current user can't modify other user's post @@ -54,8 +55,7 @@ def test_create(client, auth, app): client.post('/create', data={'title': 'created', 'body': ''}) with app.app_context(): - db = get_db() - count = db.execute('SELECT COUNT(id) FROM post').fetchone()[0] + count = Post.query.count() assert count == 2 @@ -65,9 +65,8 @@ def test_update(client, auth, app): client.post('/1/update', data={'title': 'updated', 'body': ''}) with app.app_context(): - db = get_db() - post = db.execute('SELECT * FROM post WHERE id = 1').fetchone() - assert post['title'] == 'updated' + post = Post.query.filter_by(id=1).first() + assert post.title == 'updated' @pytest.mark.parametrize('path', ( @@ -85,6 +84,5 @@ def test_delete(client, auth, app): assert response.headers['Location'] == 'http://localhost/' with app.app_context(): - db = get_db() - post = db.execute('SELECT * FROM post WHERE id = 1').fetchone() + post = Post.query.filter_by(id=1).first() assert post is None diff --git a/tests/test_db.py b/tests/test_db.py index 0132ba2..0492387 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -1,18 +1,8 @@ import sqlite3 import pytest -from flaskr.db import get_db - -def test_get_close_db(app): - with app.app_context(): - db = get_db() - assert db is get_db() - - with pytest.raises(sqlite3.ProgrammingError) as e: - db.execute('SELECT 1') - - assert 'closed' in str(e) +from flaskr.db import Post, User def test_init_db_command(runner, monkeypatch): class Recorder(object): @@ -25,3 +15,10 @@ def test_init_db_command(runner, monkeypatch): result = runner.invoke(args=['init-db']) assert 'Initialized' in result.output assert Recorder.called + +def test_repr(): + user = User(username='test') + assert 'test' in repr(user) + + post = Post(title='test') + assert 'test' in repr(post)