6/25/13

Testing thread-safety with JUnit

Here's a scenario showing how to test whether your code is thread-safe, in the form of a JUnit integration test. The example is a bank account (just an account number and a balance) plus some logic that updates it. The application is a simple Spring/Hibernate/PostgreSQL app.

1. Application code:

  • BankAccount.java:
        public class BankAccount {
            Integer id;
            String number;
            Integer balance;
            (...)
        }
    
  • BankAccountDao.java:
        public interface BankAccountDao {
            BankAccount get(String number);
            void update(BankAccount bankAccount);
        }
    
  • BankAccountDaoImpl.java:
        public class BankAccountDaoImpl extends HibernateDaoSupport implements BankAccountDao {
            @Override
            public BankAccount get(String number) {
                return (BankAccount) DataAccessUtils.singleResult(
                    getHibernateTemplate().find("from BankAccount where number = ?", number));
            }
            @Override
            public void update(BankAccount bankAccount) {
                getHibernateTemplate().update(bankAccount);
            }
        }
    
  • BankService.java:
        public interface BankService {
            /**
             * @param accountNumber account number
             * @param amount        amount of money, positive or negative
             */
            void transfer(String accountNumber, Integer amount);
        }
    
  • BankServiceImpl.java:
        public class BankServiceImpl implements BankService {
            private BankAccountDao bankAccountDao;
            @Override
            public void transfer(String accountNumber, Integer amount) {
                BankAccount bankAccount = bankAccountDao.get(accountNumber);
                bankAccount.setBalance(bankAccount.getBalance() + amount);
                bankAccountDao.update(bankAccount);
            }
        }
    
There's also some XML configuration (Spring, Hibernate, transactions, etc.), but it's not relevant here. What matters is that the transaction interceptor wraps the transfer() method, so transfer() runs inside a transaction.
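Conceptually, the transaction interceptor sits in a proxy in front of BankServiceImpl: it begins a transaction, calls the real transfer(), then commits, or rolls back if an exception escapes. Below is a toy sketch of that idea using a JDK dynamic proxy. It only records "begin"/"commit"/"rollback" in a list; TxProxyDemo and withTransaction are hypothetical names, not Spring's API. Spring's default rule is to roll back on unchecked exceptions only, while this simplified version rolls back on any exception.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.List;

class TxProxyDemo {

    // Same shape as the post's BankService interface.
    interface BankService {
        void transfer(String accountNumber, Integer amount);
    }

    // Records what the toy "transaction manager" did instead of talking to a database.
    static final List<String> log = new ArrayList<>();

    // Wraps any BankService so that every call runs between "begin" and "commit"/"rollback".
    static BankService withTransaction(BankService target) {
        InvocationHandler handler = (proxy, method, args) -> {
            log.add("begin");
            try {
                Object result = method.invoke(target, args);
                log.add("commit");
                return result;
            } catch (InvocationTargetException e) {
                log.add("rollback");
                throw e.getCause(); // rethrow the exception transfer() actually threw
            }
        };
        return (BankService) Proxy.newProxyInstance(
                BankService.class.getClassLoader(), new Class<?>[] {BankService.class}, handler);
    }

    public static void main(String[] args) {
        BankService service = withTransaction(
                (number, amount) -> log.add("transfer " + number + " " + amount));
        service.transfer("10-1000", 1);
        System.out.println(log); // [begin, transfer 10-1000 1, commit]
    }
}
```

The caller only sees the BankService interface, which is why the service code itself contains no transaction handling.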

2. The JUnit integration test

The code above can be tested quite easily with a simple Spring-based JUnit test case. I initially copied the code from this excellent blog post and then made a few small modifications of my own.
  • BankServiceTest.java:
    import org.junit.Assert;
    import org.junit.Test;
    import org.junit.runner.RunWith;
    import org.unitils.UnitilsJUnit4TestClassRunner;
    import org.unitils.dbunit.annotation.DataSet;
    import org.unitils.orm.hibernate.HibernateUnitils;
    import org.unitils.reflectionassert.ReflectionAssert;
    import org.unitils.spring.annotation.SpringApplicationContext;
    import org.unitils.spring.annotation.SpringBeanByType;
    
    import java.util.ArrayList;
    import java.util.List;
    import java.util.concurrent.Callable;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import java.util.concurrent.Future;
    
    /**
     */
    @SpringApplicationContext({"classpath:/application-dao.xml", "classpath:/application-tx.xml", "classpath:/application-test-datasource.xml"})
    @DataSet("BankServiceTest.xml")
    @RunWith(UnitilsJUnit4TestClassRunner.class)
    public class BankServiceTest {
    
        @SpringBeanByType
        private BankService bankService;
    
        @SpringBeanByType
        private BankAccountDao bankAccountDao;
    
        private int threadCount = 200;
        private int amount = 1;
    
        @Test
        public void testUpdateBalance() throws Exception {
            Assert.assertEquals("The balance is 1000", 1000, bankAccountDao.get("10-1000").getBalance().intValue());
            ExecutorService executorService = Executors.newFixedThreadPool(threadCount);
            List<Future<Void>> futures = new ArrayList<Future<Void>>();
            for (int x = 0; x < threadCount; x++) {
                Callable<Void> callable = new Callable<Void>() {
                    @Override
                    public Void call() throws Exception {
                        bankService.transfer("10-1000", amount);
                        return null;
                    }
                };
                Future<Void> submit = executorService.submit(callable);
                futures.add(submit);
            }
    
            List<Exception> exceptions = new ArrayList<Exception>();
            for (Future<Void> future : futures) {
                try {
                    future.get();
                } catch (Exception e) {
                    exceptions.add(e);
                    e.printStackTrace(System.err);
                }
            }
    
            executorService.shutdown();
    
            HibernateUnitils.getSession().clear();
            BankAccount bankAccount = bankAccountDao.get("10-1000");
            ReflectionAssert.assertReflectionEquals("No exceptions", new ArrayList<Exception>(), exceptions);
            Assert.assertEquals("Balance is 1000, again", 1200, bankAccount.getBalance().intValue());
        }
    }
        
The initial account balance is 1000 USD. Then we add 1 USD 200 times in parallel, so at the end the balance should be 1200 USD. Here's a step-by-step explanation:
1. the two @SpringBeanByType fields - Spring beans are injected into the test case (I used Unitils here),
2. Executors.newFixedThreadPool(threadCount) - creates a pool of 200 threads,
3. bankService.transfer("10-1000", amount) - invokes the transfer() method 200 times in parallel,
4. exceptions.add(e) - collects any exceptions that may have been raised (we expect none),
5. executorService.shutdown() - shuts down the thread pool,
6. HibernateUnitils.getSession().clear() - clears the Hibernate session cache manually, so that we read the new balance rather than the cached 1000,
7. ReflectionAssert.assertReflectionEquals(...) - checks that no exceptions occurred,
8. Assert.assertEquals(...) on the balance - checks that the account balance was correctly incremented to 1200.
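Step 6 is needed because a Hibernate session keeps the objects it has already loaded: the account was read (balance 1000) at the start of the test, and a later query in the same session hands back that same, now stale, object. Here is a toy model of this identity-map behaviour in plain Java; FirstLevelCacheDemo, Session and the map-based "database" are hypothetical and this is not how Hibernate is implemented.

```java
import java.util.HashMap;
import java.util.Map;

class FirstLevelCacheDemo {

    // Hypothetical "database": account number -> committed balance.
    static final Map<String, Integer> database = new HashMap<>();

    static class Account {
        final String number;
        final int balance;
        Account(String number, int balance) {
            this.number = number;
            this.balance = balance;
        }
    }

    // Toy model of a session's identity map (not Hibernate's actual code).
    static class Session {
        private final Map<String, Account> loaded = new HashMap<>();

        Account get(String number) {
            int row = database.get(number);      // the query still reads the current row...
            Account cached = loaded.get(number);
            if (cached != null) {
                return cached;                   // ...but an already-loaded object is returned unchanged
            }
            Account account = new Account(number, row);
            loaded.put(number, account);
            return account;
        }

        void clear() {
            loaded.clear();                      // forget every object loaded so far
        }
    }

    public static void main(String[] args) {
        database.put("10-1000", 1000);
        Session session = new Session();
        System.out.println(session.get("10-1000").balance); // 1000
        database.put("10-1000", 1200);                       // other transactions commit a new balance
        System.out.println(session.get("10-1000").balance); // still 1000: the stale cached object
        session.clear();
        System.out.println(session.get("10-1000").balance); // 1200: reloaded after clear()
    }
}
```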
When running the test, it turns out that there is a problem with the code:
    java.lang.AssertionError: Balance is 1000, again expected:<1200> but was:<1046>
     at org.junit.Assert.fail(Assert.java:74)
     at org.junit.Assert.failNotEquals(Assert.java:448)
     at org.junit.Assert.assertEquals(Assert.java:102)
     at org.junit.Assert.assertEquals(Assert.java:323)
     at BankServiceTest.testUpdateBalance(BankServiceTest.java:66)
     at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
     at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)
     at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
        (...)
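The failure isn't specific to Hibernate or PostgreSQL. Here is a minimal plain-Java sketch of the same read-modify-write sequence run from 200 threads, with an in-memory account instead of a database row (LostUpdateDemo and its members are hypothetical names). Unlike the original test, it uses a CountDownLatch as a start gate so that all threads begin at about the same moment, which makes overlapping reads more likely. How many updates get lost varies from run to run and can even be zero, so the only guarantee is a result between 1001 and 1200.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

class LostUpdateDemo {

    // Hypothetical in-memory stand-in for the BankAccount row.
    static class Account {
        private int balance;
        Account(int balance) { this.balance = balance; }
        int getBalance() { return balance; }
        void setBalance(int balance) { this.balance = balance; }
    }

    // Same read-modify-write shape as BankServiceImpl.transfer(), with no locking.
    static void transfer(Account account, int amount) {
        int current = account.getBalance();   // read
        Thread.yield();                       // give other threads a chance to read the same value
        account.setBalance(current + amount); // write back, possibly overwriting another thread's update
    }

    // Starts `threads` transfers of 1 at (roughly) the same moment and returns the final balance.
    static int run(int threads) {
        Account account = new Account(1000);
        CountDownLatch startGate = new CountDownLatch(1);
        ExecutorService pool = Executors.newFixedThreadPool(threads);
        try {
            List<Future<?>> futures = new ArrayList<>();
            for (int i = 0; i < threads; i++) {
                futures.add(pool.submit(() -> {
                    startGate.await();        // every task waits here until the gate opens
                    transfer(account, 1);
                    return null;
                }));
            }
            startGate.countDown();            // release all threads at once
            for (Future<?> future : futures) {
                future.get();                 // waits for the task; its writes become visible here
            }
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
        return account.getBalance();
    }

    public static void main(String[] args) {
        System.out.println("Final balance: " + run(200) + " (1200 if no update was lost)");
    }
}
```

The pool has as many threads as there are tasks, so every task can reach the start gate before it opens; with a smaller pool the queued tasks would never start and the latch would block forever.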
        

3. Code fix

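The balance ends up at 1046 instead of 1200 because transfer() is a read-modify-write sequence: concurrent transactions read the same balance, each adds 1 to its own copy and writes it back, so later writes overwrite earlier ones and those increments are lost. We'll try to fix that with a pessimistic lock, so that a transaction locks the account row when it reads it and other transfers on the same account have to wait until that transaction ends: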
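  • BankAccountDaoImpl.java:
        // Imports assume Spring 3.x with its Hibernate 3 support package (org.springframework.orm.hibernate3)
        import java.sql.SQLException;
        import java.util.List;

        import org.hibernate.HibernateException;
        import org.hibernate.LockMode;
        import org.hibernate.Session;
        import org.springframework.dao.support.DataAccessUtils;
        import org.springframework.orm.hibernate3.HibernateCallback;
        import org.springframework.orm.hibernate3.support.HibernateDaoSupport;

        public class BankAccountDaoImpl extends HibernateDaoSupport implements BankAccountDao {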
            @Override
            public BankAccount get(String number) {
                return (BankAccount) DataAccessUtils.singleResult(getHibernateTemplate().find("from BankAccount where number = ?", number));
            }
            @Override
            public BankAccount getForUpdate(final String number) {
                // PESSIMISTIC_WRITE makes the query lock the selected row (SELECT ... FOR UPDATE on PostgreSQL);
                // the lock is held until the surrounding transaction commits or rolls back
                return (BankAccount) DataAccessUtils.singleResult(getHibernateTemplate().executeFind(new HibernateCallback<List<BankAccount>>() {
                    @Override
                    public List<BankAccount> doInHibernate(Session session) throws HibernateException, SQLException {
                        return session.createQuery("from BankAccount ba where ba.number = :number")
                            .setLockMode("ba", LockMode.PESSIMISTIC_WRITE)
                            .setString("number", number).list();
                    }
                }));
            }
            @Override
            public void update(BankAccount bankAccount) {
                getHibernateTemplate().update(bankAccount);
            }
        }
    
  • BankAccountDao.java:
        public interface BankAccountDao {
            BankAccount get(String number);
            BankAccount getForUpdate(String number);
            void update(BankAccount bankAccount);
        }
    
  • BankServiceImpl.java:
        public class BankServiceImpl implements BankService {
            private BankAccountDao bankAccountDao;
            @Override
            public void transfer(String accountNumber, Integer amount) {
                BankAccount bankAccount = bankAccountDao.getForUpdate(accountNumber);
                bankAccount.setBalance(bankAccount.getBalance() + amount);
                bankAccountDao.update(bankAccount);
            }
        }
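As an in-memory analogue of this fix, here is a sketch where a ReentrantLock per account stands in for the database row lock that getForUpdate() takes (LockedTransferDemo and its members are hypothetical names). Because the lock is held from the read until the write, no increment can be lost and the result is always 1200. With the real row lock the release happens at commit or rollback, which is why getForUpdate() and update() must run in the same transaction, as they do inside transfer().

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.locks.ReentrantLock;

class LockedTransferDemo {

    // Hypothetical in-memory account; rowLock plays the role of the database row lock.
    static class Account {
        final ReentrantLock rowLock = new ReentrantLock();
        private int balance;
        Account(int balance) { this.balance = balance; }
        int getBalance() { return balance; }
        void setBalance(int balance) { this.balance = balance; }
    }

    // Analogue of getForUpdate() + update() in one transaction:
    // take the lock before reading and hold it until the new balance is written.
    static void transfer(Account account, int amount) {
        account.rowLock.lock();
        try {
            int current = account.getBalance();
            Thread.yield();
            account.setBalance(current + amount);
        } finally {
            account.rowLock.unlock(); // corresponds to the row lock being released at commit/rollback
        }
    }

    // Runs `threads` concurrent transfers of 1 against a balance of 1000 and returns the final balance.
    static int run(int threads) {
        Account account = new Account(1000);
        ExecutorService pool = Executors.newFixedThreadPool(threads);
        try {
            List<Future<?>> futures = new ArrayList<>();
            for (int i = 0; i < threads; i++) {
                futures.add(pool.submit(() -> transfer(account, 1)));
            }
            for (Future<?> future : futures) {
                future.get();
            }
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
        return account.getBalance();
    }

    public static void main(String[] args) {
        System.out.println("Final balance: " + run(200)); // Final balance: 1200
    }
}
```

The trade-off is the same as with the database lock: transfers on the same account are serialized, so correctness is bought with reduced concurrency on that account.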
    

4. That's it!


