Wednesday, June 15, 2016

Missing matched documents on searches and updates reproduction

This blog post recently exposed an interesting MongoDB concurrency caveat: matching documents won't be found (or updated) if they are being reindexed at the same time.

The only thing missing from that entry was a way to reproduce the issue. So I decided to write a test that you can run against your version of MongoDB to check whether it is still a problem.

Here it is:

package co.uk.matejtymes.mongodb;

import com.mongodb.*;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import java.util.*;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadLocalRandom;

import static com.mongodb.BasicDBObjectBuilder.start;
import static java.util.Arrays.asList;
import static java.util.UUID.randomUUID;
import static java.util.concurrent.Executors.newFixedThreadPool;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.junit.Assert.assertThat;

/**
 * Reproduction test for a MongoDB concurrency caveat: a query (or multi-update) that matches on an
 * indexed field may miss documents, or see them twice, while those documents are concurrently
 * updated (and therefore re-indexed).
 *
 * <p>Each attempt releases {@code concurrentUpdates} single-document state flips and one query (or
 * multi-update) matching every document through a shared start gate. It then checks that every
 * document was matched exactly once. A failure indicates that the MongoDB instance under test
 * exhibits the problem.
 *
 * <p>Requires a reachable MongoDB instance; connection details are set in {@link #setUp()}.
 */
public class ReindexFailureTest {

    private static final String STATE_FIELD = "state";

    /** Upper bound for any wait on the start gate or on update tasks, so a broken attempt fails instead of hanging. */
    private static final long WAIT_TIMEOUT_SECONDS = 30;

    private MongoClient mongo;
    private DBCollection coll;

    @Before
    public void setUp() throws Exception {
        // todo: provide connection details for your mongoDB instance
        mongo = new MongoClient("localhost", 27017);
        DB db = mongo.getDB("testDb");

        coll = db.getCollection("indexTest");
        // Start from an empty collection: documents left behind by an aborted earlier run would also
        // match the "$in" queries below and break the size assertions.
        coll.drop();
    }

    @After
    public void tearDown() throws Exception {
        try {
            if (coll != null) {
                coll.drop();
            }
        } finally {
            // Release the client created in setUp(); previously it was never closed.
            if (mongo != null) {
                mongo.close();
            }
        }
    }

    /**
     * Queries all documents by their indexed state while other threads flip individual states. The
     * query must return every document exactly once.
     */
    @Test
    public void shouldFindAllMatchingItemsEvenWhenRecalculatingIndex() throws Exception {
        int docCount = 250;
        int concurrentUpdates = 40;
        int attemptsCount = 1_000;

        List<String> stateValues = asList("Active", "Inactive");

        coll.createIndex(new BasicDBObject(STATE_FIELD, 1));

        List<String> allIds = createNDocumentsWithState(docCount, stateValues);

        // One thread per updater, so that all updaters can wait at the start gate at the same time.
        ExecutorService executor = newFixedThreadPool(concurrentUpdates);
        try {
            for (int attempt = 1; attempt <= attemptsCount; attempt++) {
                System.out.println(attempt + ". attempt");

                // The gate is shared by every updater plus this thread, which runs the query.
                CountDownLatch beginLatch = new CountDownLatch(concurrentUpdates + 1);
                List<Future<?>> updates = submitStateUpdates(executor, concurrentUpdates, allIds, stateValues, beginLatch);

                List<String> foundIds = findDocumentsInState(stateValues, beginLatch);
                // Let this attempt's updates finish (rethrowing any failure) before checking results
                // or starting the next attempt.
                awaitAll(updates);

                Set<String> uniqueIds = new HashSet<>();
                Set<String> duplicateIds = new HashSet<>();
                Set<String> missingIds = new HashSet<>(allIds);

                for (String foundId : foundIds) {
                    if (!uniqueIds.add(foundId)) {
                        duplicateIds.add(foundId);
                    }
                    missingIds.remove(foundId);
                }

                if (!missingIds.isEmpty()) {
                    System.err.println(missingIds.size() + ". missingIds: " + missingIds);
                }
                if (!duplicateIds.isEmpty()) {
                    System.err.println(duplicateIds.size() + ". duplicateIds: " + duplicateIds);
                }

                assertThat(foundIds, hasSize(allIds.size()));
                assertThat(missingIds, hasSize(0));
                assertThat(duplicateIds, hasSize(0));
            }
        } finally {
            // Also reached when an assertion fails, so worker threads never outlive the test.
            executor.shutdownNow();
            executor.awaitTermination(3, SECONDS);
        }
    }

    /**
     * Runs a multi-update that matches all documents by their indexed state while other threads flip
     * individual states. Every document must be matched and updated.
     */
    @Test
    public void shouldUpdateAllMatchingItemsEvenWhenRecalculatingIndex() throws Exception {
        int docCount = 250;
        int concurrentUpdates = 40;
        int attemptsCount = 1_000;

        List<String> stateValues = asList("Active", "Inactive");

        coll.createIndex(new BasicDBObject(STATE_FIELD, 1));

        List<String> allIds = createNDocumentsWithState(docCount, stateValues);

        // One thread per updater, so that all updaters can wait at the start gate at the same time.
        ExecutorService executor = newFixedThreadPool(concurrentUpdates);
        try {
            for (int attempt = 1; attempt <= attemptsCount; attempt++) {
                System.out.println(attempt + ". attempt");

                // A fresh field per attempt, so documents updated in earlier attempts are not counted.
                String fieldToUpdate = "field" + attempt;
                Object valueToSet = true;

                // The gate is shared by every updater plus this thread, which runs the multi-update.
                CountDownLatch beginLatch = new CountDownLatch(concurrentUpdates + 1);
                List<Future<?>> updates = submitStateUpdates(executor, concurrentUpdates, allIds, stateValues, beginLatch);

                BasicDBObject query = new BasicDBObject(STATE_FIELD, new BasicDBObject("$in", stateValues));
                BasicDBObject update = new BasicDBObject("$set", new BasicDBObject(fieldToUpdate, valueToSet));

                beginLatch.countDown();
                awaitLatch(beginLatch);
                int n = coll.updateMulti(query, update).getN();
                awaitAll(updates);

                List<String> updatedIds = findIds(new BasicDBObject(fieldToUpdate, valueToSet));

                Set<String> missingIds = new HashSet<>(allIds);
                missingIds.removeAll(updatedIds);

                if (!missingIds.isEmpty()) {
                    System.err.println(missingIds.size() + ". missingIds: " + missingIds);
                }
                if (n != allIds.size()) {
                    System.err.println("n = " + n);
                }
                if (updatedIds.size() != allIds.size()) {
                    System.err.println("updateIds = " + updatedIds.size());
                }

                assertThat(n, equalTo(allIds.size()));
                assertThat(updatedIds, hasSize(allIds.size()));
                assertThat(missingIds, hasSize(0));
            }
        } finally {
            // Also reached when an assertion fails, so worker threads never outlive the test.
            executor.shutdownNow();
            executor.awaitTermination(3, SECONDS);
        }
    }

    /* ====================== */
    /* --- helper methods --- */
    /* ====================== */

    /**
     * Inserts {@code docCount} documents with random UUID string ids. Their states cycle through
     * {@code stateValues} in order.
     *
     * @return the inserted ids, in insertion order
     */
    private List<String> createNDocumentsWithState(int docCount, List<String> stateValues) {
        List<String> ids = new ArrayList<>(docCount);

        for (int i = 0; i < docCount; i++) {
            String id = randomUUID().toString();
            String state = stateValues.get(i % stateValues.size());

            DBObject dbObject = start()
                    .add("_id", id)
                    .add(STATE_FIELD, state)
                    .get();
            coll.insert(dbObject);

            ids.add(id);
        }
        return ids;
    }

    /**
     * Submits {@code count} tasks. Each task flips the state of a randomly chosen document once
     * everyone has reached {@code beginLatch}. The executor must be able to run all {@code count}
     * tasks at once, otherwise the gate can only open by timing out.
     *
     * @return one future per task; {@code get()} rethrows any failure of that task
     */
    private List<Future<?>> submitStateUpdates(ExecutorService executor, int count, List<String> ids,
                                               List<String> stateValues, CountDownLatch beginLatch) {
        List<Future<?>> futures = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            futures.add(executor.submit(() -> updateState(pickRandomItem(ids), stateValues, beginLatch)));
        }
        return futures;
    }

    /**
     * Reads the current state of document {@code id} and prepares an update that sets it to the first
     * other value in {@code stateValues}. The write is issued only once every participant has reached
     * {@code beginLatch}.
     *
     * @throws IllegalStateException if the document does not exist, if {@code stateValues} holds no
     *                               other state, or if the gate does not open in time
     */
    private void updateState(String id, List<String> stateValues, CountDownLatch beginLatch) {
        BasicDBObject query = new BasicDBObject("_id", id);
        BasicDBObject update;
        try {
            DBObject current = coll.findOne(query);
            if (current == null) {
                throw new IllegalStateException("Document not found: " + id);
            }
            String oldStateValue = (String) current.get(STATE_FIELD);
            String newStateValue = stateValues.stream()
                    .filter(state -> !state.equals(oldStateValue))
                    .findFirst()
                    .orElseThrow(() -> new IllegalStateException(
                            "No state other than " + oldStateValue + " in " + stateValues));
            update = new BasicDBObject("$set", new BasicDBObject(STATE_FIELD, newStateValue));
        } finally {
            // Count down even if preparation failed, so the other participants are not left waiting
            // at the gate. The failure still reaches the test through this task's Future.
            beginLatch.countDown();
        }
        awaitLatch(beginLatch);
        coll.update(query, update);
    }

    /**
     * Passes the start gate, then returns the {@code _id} of every document whose state is one of
     * {@code stateValues}, in cursor order and including any duplicates.
     */
    private List<String> findDocumentsInState(List<String> stateValues, CountDownLatch beginLatch) {
        BasicDBObject query = new BasicDBObject(STATE_FIELD, new BasicDBObject("$in", stateValues));

        beginLatch.countDown();
        awaitLatch(beginLatch);
        return findIds(query);
    }

    /** Runs {@code query} and collects the {@code _id} of every returned document, closing the cursor afterwards. */
    private List<String> findIds(DBObject query) {
        List<String> ids = new ArrayList<>();
        try (DBCursor cursor = coll.find(query)) {
            while (cursor.hasNext()) {
                ids.add((String) cursor.next().get("_id"));
            }
        }
        return ids;
    }

    /**
     * Blocks until {@code latch} reaches zero.
     *
     * @throws IllegalStateException if that does not happen within {@link #WAIT_TIMEOUT_SECONDS}, or if
     *                               the thread is interrupted (the interrupt flag is restored)
     */
    private static void awaitLatch(CountDownLatch latch) {
        try {
            if (!latch.await(WAIT_TIMEOUT_SECONDS, SECONDS)) {
                throw new IllegalStateException("Timed out waiting for all participants to reach the start gate");
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while waiting at the start gate", e);
        }
    }

    /** Waits for every future to complete, rethrowing the first failure (wrapped in ExecutionException) or a timeout. */
    private static void awaitAll(List<Future<?>> futures) throws Exception {
        for (Future<?> future : futures) {
            future.get(WAIT_TIMEOUT_SECONDS, SECONDS);
        }
    }

    /** Returns a uniformly random element of the non-empty {@code values}; safe to call from any thread. */
    private static <T> T pickRandomItem(List<T> values) {
        return values.get(ThreadLocalRandom.current().nextInt(values.size()));
    }
}