From 7adc80c85d8815b4e2f5b7454c31bedc7f04de34 Mon Sep 17 00:00:00 2001
From: Yu Zhao <yuzhao@google.com>
Date: Mon, 5 Apr 2021 04:17:41 -0600
Subject: [PATCH] BACKPORT: FROMLIST: mm: multigenerational lru:
 mm_struct list

In order to scan page tables, we add an infrastructure to maintain
either a system-wide mm_struct list or per-memcg mm_struct lists, and
track whether an mm_struct is being used or has been used since the
last scan.

Multiple threads can concurrently work on the same mm_struct list, and
each of them will be given a different mm_struct belonging to a
process that has been scheduled since the last scan.

Signed-off-by: Yu Zhao <yuzhao@google.com>
Tested-by: Konstantin Kharlamov <Hi-Angel@yandex.ru>
(am from https://lore.kernel.org/patchwork/patch/1432184/)

BUG=b:123039911
TEST=Built

Change-Id: I25d9eda8c6bdc7c3653b9f210a159d6c247c81e8
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/third_party/kernel/+/2987188
Reviewed-by: Yu Zhao <yuzhao@chromium.org>
Tested-by: Yu Zhao <yuzhao@chromium.org>
Commit-Queue: Sonny Rao <sonnyrao@chromium.org>
Commit-Queue: Yu Zhao <yuzhao@chromium.org>
---
 fs/exec.c                  |   2 +
 include/linux/memcontrol.h |   6 +
 include/linux/mm_types.h   | 107 ++++++++++++
 kernel/exit.c              |   1 +
 kernel/fork.c              |  10 ++
 kernel/sched/core.c        |   2 +
 mm/memcontrol.c            |  28 ++++
 mm/mmu_context.c           |   1 +
 mm/vmscan.c                | 324 +++++++++++++++++++++++++++++++++++++
 9 files changed, 481 insertions(+)

--- a/fs/exec.c
+++ b/fs/exec.c
@@ -1059,6 +1059,7 @@ static int exec_mmap(struct mm_struct *m
 	active_mm = tsk->active_mm;
 	tsk->active_mm = mm;
 	tsk->mm = mm;
+	lru_gen_add_mm(mm);
 	/*
 	 * This prevents preemption while active_mm is being loaded and
 	 * it and mm are being updated, which could cause problems for
@@ -1069,6 +1070,7 @@ static int exec_mmap(struct mm_struct *m
 	if (!IS_ENABLED(CONFIG_ARCH_WANT_IRQS_OFF_ACTIVATE_MM))
 		local_irq_enable();
 	activate_mm(active_mm, mm);
+	lru_gen_switch_mm(active_mm, mm);
 	if (IS_ENABLED(CONFIG_ARCH_WANT_IRQS_OFF_ACTIVATE_MM))
 		local_irq_enable();
 	tsk->mm->vmacache_seqnum = 0;
--- a/include/linux/memcontrol.h
+++ b/include/linux/memcontrol.h
@@ -182,6 +182,8 @@ struct memcg_padding {
 #define MEMCG_PADDING(name)
 #endif
 
+struct lru_gen_mm_list;
+
 /*
  * Remember four most recent foreign writebacks with dirty pages in this
  * cgroup.  Inode sharing is expected to be uncommon and, even if we miss
@@ -334,6 +336,10 @@ struct mem_cgroup {
 	struct deferred_split deferred_split_queue;
 #endif
 
+#ifdef CONFIG_LRU_GEN
+	struct lru_gen_mm_list *mm_list;
+#endif
+
 	struct mem_cgroup_per_node *nodeinfo[0];
 	/* WARNING: nodeinfo must be the last member here */
 };
--- a/include/linux/mm_types.h
+++ b/include/linux/mm_types.h
@@ -14,6 +14,8 @@
 #include <linux/uprobes.h>
 #include <linux/page-flags-layout.h>
 #include <linux/workqueue.h>
+#include <linux/nodemask.h>
+#include <linux/mmdebug.h>
 
 #include <asm/mmu.h>
 
@@ -524,6 +526,22 @@ struct mm_struct {
 		atomic_long_t hugetlb_usage;
 #endif
 		struct work_struct async_put_work;
+#ifdef CONFIG_LRU_GEN
+		struct {
+			/* the node of a global or per-memcg mm_struct list */
+			struct list_head list;
+#ifdef CONFIG_MEMCG
+			/* points to the memcg of the owner task above */
+			struct mem_cgroup *memcg;
+#endif
+			/* whether this mm_struct has been used since the last walk */
+			nodemask_t nodes;
+#ifndef CONFIG_ARCH_WANT_BATCHED_UNMAP_TLB_FLUSH
+			/* the number of CPUs using this mm_struct */
+			atomic_t nr_cpus;
+#endif
+		} lrugen;
+#endif
 	} __randomize_layout;
 
 	/*
@@ -550,6 +568,95 @@ static inline cpumask_t *mm_cpumask(stru
 	return (struct cpumask *)&mm->cpu_bitmap;
 }
 
+#ifdef CONFIG_LRU_GEN
+
+void lru_gen_init_mm(struct mm_struct *mm);
+void lru_gen_add_mm(struct mm_struct *mm);
+void lru_gen_del_mm(struct mm_struct *mm);
+#ifdef CONFIG_MEMCG
+int lru_gen_alloc_mm_list(struct mem_cgroup *memcg);
+void lru_gen_free_mm_list(struct mem_cgroup *memcg);
+void lru_gen_migrate_mm(struct mm_struct *mm);
+#endif
+
+/* Track the usage of each mm_struct so that we can skip inactive ones. */
+static inline void lru_gen_switch_mm(struct mm_struct *old, struct mm_struct *new)
+{
+	/* exclude init_mm, efi_mm, etc. */
+	if (!core_kernel_data((unsigned long)old)) {
+		VM_BUG_ON(old == &init_mm);
+
+		nodes_setall(old->lrugen.nodes);
+#ifndef CONFIG_ARCH_WANT_BATCHED_UNMAP_TLB_FLUSH
+		atomic_dec(&old->lrugen.nr_cpus);
+		VM_BUG_ON_MM(atomic_read(&old->lrugen.nr_cpus) < 0, old);
+#endif
+	} else
+		VM_BUG_ON_MM(READ_ONCE(old->lrugen.list.prev) ||
+			     READ_ONCE(old->lrugen.list.next), old);
+
+	if (!core_kernel_data((unsigned long)new)) {
+		VM_BUG_ON(new == &init_mm);
+
+#ifndef CONFIG_ARCH_WANT_BATCHED_UNMAP_TLB_FLUSH
+		atomic_inc(&new->lrugen.nr_cpus);
+		VM_BUG_ON_MM(atomic_read(&new->lrugen.nr_cpus) < 0, new);
+#endif
+	} else
+		VM_BUG_ON_MM(READ_ONCE(new->lrugen.list.prev) ||
+			     READ_ONCE(new->lrugen.list.next), new);
+}
+
+/* Return whether this mm_struct is being used on any CPUs. */
+static inline bool lru_gen_mm_is_active(struct mm_struct *mm)
+{
+#ifdef CONFIG_ARCH_WANT_BATCHED_UNMAP_TLB_FLUSH
+	return !cpumask_empty(mm_cpumask(mm));
+#else
+	return atomic_read(&mm->lrugen.nr_cpus);
+#endif
+}
+
+#else /* CONFIG_LRU_GEN */
+
+static inline void lru_gen_init_mm(struct mm_struct *mm)
+{
+}
+
+static inline void lru_gen_add_mm(struct mm_struct *mm)
+{
+}
+
+static inline void lru_gen_del_mm(struct mm_struct *mm)
+{
+}
+
+#ifdef CONFIG_MEMCG
+static inline int lru_gen_alloc_mm_list(struct mem_cgroup *memcg)
+{
+	return 0;
+}
+
+static inline void lru_gen_free_mm_list(struct mem_cgroup *memcg)
+{
+}
+
+static inline void lru_gen_migrate_mm(struct mm_struct *mm)
+{
+}
+#endif
+
+static inline void lru_gen_switch_mm(struct mm_struct *old, struct mm_struct *new)
+{
+}
+
+static inline bool lru_gen_mm_is_active(struct mm_struct *mm)
+{
+	return false;
+}
+
+#endif /* CONFIG_LRU_GEN */
+
 struct mmu_gather;
 extern void tlb_gather_mmu(struct mmu_gather *tlb, struct mm_struct *mm,
 				unsigned long start, unsigned long end);
--- a/kernel/exit.c
+++ b/kernel/exit.c
@@ -470,6 +470,7 @@ assign_new_owner:
 		goto retry;
 	}
 	WRITE_ONCE(mm->owner, c);
+	lru_gen_migrate_mm(mm);
 	task_unlock(c);
 	put_task_struct(c);
 }
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -672,6 +672,7 @@ static void check_mm(struct mm_struct *m
 #if defined(CONFIG_TRANSPARENT_HUGEPAGE) && !USE_SPLIT_PMD_PTLOCKS
 	VM_BUG_ON_MM(mm->pmd_huge_pte, mm);
 #endif
+	VM_BUG_ON_MM(lru_gen_mm_is_active(mm), mm);
 }
 
 #define allocate_mm()	(kmem_cache_alloc(mm_cachep, GFP_KERNEL))
@@ -1045,6 +1046,7 @@ static struct mm_struct *mm_init(struct
 		goto fail_nocontext;
 
 	mm->user_ns = get_user_ns(user_ns);
+	lru_gen_init_mm(mm);
 	return mm;
 
 fail_nocontext:
@@ -1087,6 +1089,7 @@ static inline void __mmput(struct mm_str
 	}
 	if (mm->binfmt)
 		module_put(mm->binfmt->module);
+	lru_gen_del_mm(mm);
 	mmdrop(mm);
 }
 
@@ -2394,6 +2397,13 @@ long _do_fork(struct kernel_clone_args *
 		get_task_struct(p);
 	}
 
+	if (IS_ENABLED(CONFIG_LRU_GEN) && !(clone_flags & CLONE_VM)) {
+		/* lock the task to synchronize with memcg migration */
+		task_lock(p);
+		lru_gen_add_mm(p->mm);
+		task_unlock(p);
+	}
+
 	wake_up_new_task(p);
 
 	/* forking complete and child started to run, tell ptracer */
--- a/kernel/sched/core.c
+++ b/kernel/sched/core.c
@@ -3471,6 +3471,7 @@ context_switch(struct rq *rq, struct tas
 		 * finish_task_switch()'s mmdrop().
 		 */
 		switch_mm_irqs_off(prev->active_mm, next->mm, next);
+		lru_gen_switch_mm(prev->active_mm, next->mm);
 
 		if (!prev->mm) {                        // from kernel
 			/* will mmdrop() in finish_task_switch(). */
@@ -6286,6 +6287,7 @@ void idle_task_exit(void)
 
 	if (mm != &init_mm) {
 		switch_mm(mm, &init_mm, current);
+		lru_gen_switch_mm(mm, &init_mm);
 		finish_arch_post_lock_switch();
 	}
 
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -5102,6 +5102,7 @@ static void __mem_cgroup_free(struct mem
 		free_mem_cgroup_per_node_info(memcg, node);
 	free_percpu(memcg->vmstats_percpu);
 	free_percpu(memcg->vmstats_local);
+	lru_gen_free_mm_list(memcg);
 	kfree(memcg);
 }
 
@@ -5152,6 +5153,9 @@ static struct mem_cgroup *mem_cgroup_all
 		if (alloc_mem_cgroup_per_node_info(memcg, node))
 			goto fail;
 
+	if (lru_gen_alloc_mm_list(memcg))
+		goto fail;
+
 	if (memcg_wb_domain_init(memcg, GFP_KERNEL))
 		goto fail;
 
@@ -6070,6 +6074,29 @@ static void mem_cgroup_move_task(void)
 }
 #endif
 
+#ifdef CONFIG_LRU_GEN
+static void mem_cgroup_attach(struct cgroup_taskset *tset)
+{
+	struct cgroup_subsys_state *css;
+	struct task_struct *task = NULL;
+
+	cgroup_taskset_for_each_leader(task, css, tset)
+		;
+
+	if (!task)
+		return;
+
+	task_lock(task);
+	if (task->mm && task->mm->owner == task)
+		lru_gen_migrate_mm(task->mm);
+	task_unlock(task);
+}
+#else
+static void mem_cgroup_attach(struct cgroup_taskset *tset)
+{
+}
+#endif
+
 /*
  * Cgroup retains root cgroups across [un]mount cycles making it necessary
  * to verify whether we're attached to the default hierarchy on each mount
@@ -6370,6 +6397,7 @@ struct cgroup_subsys memory_cgrp_subsys
 	.css_free = mem_cgroup_css_free,
 	.css_reset = mem_cgroup_css_reset,
 	.can_attach = mem_cgroup_can_attach,
+	.attach = mem_cgroup_attach,
 	.cancel_attach = mem_cgroup_cancel_attach,
 	.post_attach = mem_cgroup_move_task,
 	.bind = mem_cgroup_bind,
--- a/mm/mmu_context.c
+++ b/mm/mmu_context.c
@@ -34,6 +34,7 @@ void use_mm(struct mm_struct *mm)
 	}
 	tsk->mm = mm;
 	switch_mm_irqs_off(active_mm, mm, tsk);
+	lru_gen_switch_mm(active_mm, mm);
 	local_irq_enable();
 	task_unlock(tsk);
 #ifdef finish_arch_post_lock_switch
--- a/mm/vmscan.c
+++ b/mm/vmscan.c
@@ -2713,6 +2713,323 @@ static bool positive_ctrl_err(struct con
 }
 
 /******************************************************************************
+ *                          mm_struct list
+ ******************************************************************************/
+
+enum {
+	MM_SCHED_ACTIVE,	/* running processes */
+	MM_SCHED_INACTIVE,	/* sleeping processes */
+	MM_LOCK_CONTENTION,	/* lock contentions */
+	MM_VMA_INTERVAL,	/* VMAs within the range of each PUD/PMD/PTE */
+	MM_LEAF_OTHER_NODE,	/* entries not from the node under reclaim */
+	MM_LEAF_OTHER_MEMCG,	/* entries not from the memcg under reclaim */
+	MM_LEAF_OLD,		/* old entries */
+	MM_LEAF_YOUNG,		/* young entries */
+	MM_LEAF_DIRTY,		/* dirty entries */
+	MM_LEAF_HOLE,		/* non-present entries */
+	MM_NONLEAF_OLD,		/* old non-leaf PMD entries */
+	MM_NONLEAF_YOUNG,	/* young non-leaf PMD entries */
+	NR_MM_STATS
+};
+
+/* mnemonic codes for the stats above */
+#define MM_STAT_CODES		"aicvnmoydhlu"
+
+struct lru_gen_mm_list {
+	/* the head of a global or per-memcg mm_struct list */
+	struct list_head head;
+	/* protects the list */
+	spinlock_t lock;
+	struct {
+		/* set to max_seq after each round of walk */
+		unsigned long cur_seq;
+		/* the next mm on the list to walk */
+		struct list_head *iter;
+		/* to wait for the last worker to finish */
+		struct wait_queue_head wait;
+		/* the number of concurrent workers */
+		int nr_workers;
+		/* stats for debugging */
+		unsigned long stats[NR_STAT_GENS][NR_MM_STATS];
+	} nodes[0];
+};
+
+static struct lru_gen_mm_list *global_mm_list;
+
+static struct lru_gen_mm_list *alloc_mm_list(void)
+{
+	int nid;
+	struct lru_gen_mm_list *mm_list;
+
+	mm_list = kzalloc(struct_size(mm_list, nodes, nr_node_ids), GFP_KERNEL);
+	if (!mm_list)
+		return NULL;
+
+	INIT_LIST_HEAD(&mm_list->head);
+	spin_lock_init(&mm_list->lock);
+
+	for_each_node(nid) {
+		mm_list->nodes[nid].cur_seq = MIN_NR_GENS;
+		mm_list->nodes[nid].iter = &mm_list->head;
+		init_waitqueue_head(&mm_list->nodes[nid].wait);
+	}
+
+	return mm_list;
+}
+
+static struct lru_gen_mm_list *get_mm_list(struct mem_cgroup *memcg)
+{
+#ifdef CONFIG_MEMCG
+	if (!mem_cgroup_disabled())
+		return memcg ? memcg->mm_list : root_mem_cgroup->mm_list;
+#endif
+	VM_BUG_ON(memcg);
+
+	return global_mm_list;
+}
+
+void lru_gen_init_mm(struct mm_struct *mm)
+{
+	INIT_LIST_HEAD(&mm->lrugen.list);
+#ifdef CONFIG_MEMCG
+	mm->lrugen.memcg = NULL;
+#endif
+#ifndef CONFIG_ARCH_WANT_BATCHED_UNMAP_TLB_FLUSH
+	atomic_set(&mm->lrugen.nr_cpus, 0);
+#endif
+	nodes_clear(mm->lrugen.nodes);
+}
+
+void lru_gen_add_mm(struct mm_struct *mm)
+{
+	struct mem_cgroup *memcg = get_mem_cgroup_from_mm(mm);
+	struct lru_gen_mm_list *mm_list = get_mm_list(memcg);
+
+	VM_BUG_ON_MM(!list_empty(&mm->lrugen.list), mm);
+#ifdef CONFIG_MEMCG
+	VM_BUG_ON_MM(mm->lrugen.memcg, mm);
+	WRITE_ONCE(mm->lrugen.memcg, memcg);
+#endif
+	spin_lock(&mm_list->lock);
+	list_add_tail(&mm->lrugen.list, &mm_list->head);
+	spin_unlock(&mm_list->lock);
+}
+
+void lru_gen_del_mm(struct mm_struct *mm)
+{
+	int nid;
+#ifdef CONFIG_MEMCG
+	struct lru_gen_mm_list *mm_list = get_mm_list(mm->lrugen.memcg);
+#else
+	struct lru_gen_mm_list *mm_list = get_mm_list(NULL);
+#endif
+
+	spin_lock(&mm_list->lock);
+
+	for_each_node(nid) {
+		if (mm_list->nodes[nid].iter != &mm->lrugen.list)
+			continue;
+
+		mm_list->nodes[nid].iter = mm_list->nodes[nid].iter->next;
+		if (mm_list->nodes[nid].iter == &mm_list->head)
+			WRITE_ONCE(mm_list->nodes[nid].cur_seq,
+				   mm_list->nodes[nid].cur_seq + 1);
+	}
+
+	list_del_init(&mm->lrugen.list);
+
+	spin_unlock(&mm_list->lock);
+
+#ifdef CONFIG_MEMCG
+	mem_cgroup_put(mm->lrugen.memcg);
+	WRITE_ONCE(mm->lrugen.memcg, NULL);
+#endif
+}
+
+#ifdef CONFIG_MEMCG
+int lru_gen_alloc_mm_list(struct mem_cgroup *memcg)
+{
+	if (mem_cgroup_disabled())
+		return 0;
+
+	memcg->mm_list = alloc_mm_list();
+
+	return memcg->mm_list ? 0 : -ENOMEM;
+}
+
+void lru_gen_free_mm_list(struct mem_cgroup *memcg)
+{
+	kfree(memcg->mm_list);
+	memcg->mm_list = NULL;
+}
+
+void lru_gen_migrate_mm(struct mm_struct *mm)
+{
+	struct mem_cgroup *memcg;
+
+	lockdep_assert_held(&mm->owner->alloc_lock);
+
+	if (mem_cgroup_disabled())
+		return;
+
+	rcu_read_lock();
+	memcg = mem_cgroup_from_task(mm->owner);
+	rcu_read_unlock();
+	if (memcg == mm->lrugen.memcg)
+		return;
+
+	VM_BUG_ON_MM(!mm->lrugen.memcg, mm);
+	VM_BUG_ON_MM(list_empty(&mm->lrugen.list), mm);
+
+	lru_gen_del_mm(mm);
+	lru_gen_add_mm(mm);
+}
+
+static bool mm_has_migrated(struct mm_struct *mm, struct mem_cgroup *memcg)
+{
+	return READ_ONCE(mm->lrugen.memcg) != memcg;
+}
+#else
+static bool mm_has_migrated(struct mm_struct *mm, struct mem_cgroup *memcg)
+{
+	return false;
+}
+#endif
+
+struct mm_walk_args {
+	struct mem_cgroup *memcg;
+	unsigned long max_seq;
+	unsigned long start_pfn;
+	unsigned long end_pfn;
+	unsigned long next_addr;
+	int node_id;
+	int swappiness;
+	int batch_size;
+	int nr_pages[MAX_NR_GENS][ANON_AND_FILE][MAX_NR_ZONES];
+	int mm_stats[NR_MM_STATS];
+	unsigned long bitmap[0];
+};
+
+static int size_of_mm_walk_args(void)
+{
+	int size = sizeof(struct mm_walk_args);
+
+	if (IS_ENABLED(CONFIG_TRANSPARENT_HUGEPAGE) ||
+	    IS_ENABLED(CONFIG_HAVE_ARCH_PARENT_PMD_YOUNG))
+		size += sizeof(unsigned long) * BITS_TO_LONGS(PTRS_PER_PMD);
+
+	return size;
+}
+
+static void reset_mm_stats(struct lru_gen_mm_list *mm_list, bool last,
+			   struct mm_walk_args *args)
+{
+	int i;
+	int nid = args->node_id;
+	int hist = hist_from_seq_or_gen(args->max_seq);
+
+	lockdep_assert_held(&mm_list->lock);
+
+	for (i = 0; i < NR_MM_STATS; i++) {
+		WRITE_ONCE(mm_list->nodes[nid].stats[hist][i],
+			   mm_list->nodes[nid].stats[hist][i] + args->mm_stats[i]);
+		args->mm_stats[i] = 0;
+	}
+
+	if (!last || NR_STAT_GENS == 1)
+		return;
+
+	hist = hist_from_seq_or_gen(args->max_seq + 1);
+	for (i = 0; i < NR_MM_STATS; i++)
+		WRITE_ONCE(mm_list->nodes[nid].stats[hist][i], 0);
+}
+
+static bool should_skip_mm(struct mm_struct *mm, struct mm_walk_args *args)
+{
+	int type;
+	unsigned long size = 0;
+
+	if (!lru_gen_mm_is_active(mm) && !node_isset(args->node_id, mm->lrugen.nodes))
+		return true;
+
+	if (mm_is_oom_victim(mm))
+		return true;
+
+	for (type = !args->swappiness; type < ANON_AND_FILE; type++) {
+		size += type ? get_mm_counter(mm, MM_FILEPAGES) :
+			       get_mm_counter(mm, MM_ANONPAGES) +
+			       get_mm_counter(mm, MM_SHMEMPAGES);
+	}
+
+	/* leave the legwork to the rmap if mappings are too sparse */
+	if (size < max(SWAP_CLUSTER_MAX, mm_pgtables_bytes(mm) / PAGE_SIZE))
+		return true;
+
+	return !mmget_not_zero(mm);
+}
+
+/* To support multiple workers that concurrently walk an mm_struct list. */
+static bool get_next_mm(struct mm_walk_args *args, struct mm_struct **iter)
+{
+	bool last = true;
+	struct mm_struct *mm = NULL;
+	int nid = args->node_id;
+	struct lru_gen_mm_list *mm_list = get_mm_list(args->memcg);
+
+	if (*iter)
+		mmput_async(*iter);
+	else if (args->max_seq <= READ_ONCE(mm_list->nodes[nid].cur_seq))
+		return false;
+
+	spin_lock(&mm_list->lock);
+
+	VM_BUG_ON(args->max_seq > mm_list->nodes[nid].cur_seq + 1);
+	VM_BUG_ON(*iter && args->max_seq < mm_list->nodes[nid].cur_seq);
+	VM_BUG_ON(*iter && !mm_list->nodes[nid].nr_workers);
+
+	if (args->max_seq <= mm_list->nodes[nid].cur_seq) {
+		last = *iter;
+		goto done;
+	}
+
+	if (mm_list->nodes[nid].iter == &mm_list->head) {
+		VM_BUG_ON(*iter || mm_list->nodes[nid].nr_workers);
+		mm_list->nodes[nid].iter = mm_list->nodes[nid].iter->next;
+	}
+
+	while (!mm && mm_list->nodes[nid].iter != &mm_list->head) {
+		mm = list_entry(mm_list->nodes[nid].iter, struct mm_struct, lrugen.list);
+		mm_list->nodes[nid].iter = mm_list->nodes[nid].iter->next;
+		if (should_skip_mm(mm, args))
+			mm = NULL;
+
+		args->mm_stats[mm ? MM_SCHED_ACTIVE : MM_SCHED_INACTIVE]++;
+	}
+
+	if (mm_list->nodes[nid].iter == &mm_list->head)
+		WRITE_ONCE(mm_list->nodes[nid].cur_seq,
+			   mm_list->nodes[nid].cur_seq + 1);
+done:
+	if (*iter && !mm)
+		mm_list->nodes[nid].nr_workers--;
+	if (!*iter && mm)
+		mm_list->nodes[nid].nr_workers++;
+
+	last = last && !mm_list->nodes[nid].nr_workers &&
+	       mm_list->nodes[nid].iter == &mm_list->head;
+
+	reset_mm_stats(mm_list, last, args);
+
+	spin_unlock(&mm_list->lock);
+
+	*iter = mm;
+	if (mm)
+		node_clear(nid, mm->lrugen.nodes);
+
+	return last;
+}
+
+/******************************************************************************
  *                          state change
  ******************************************************************************/
 
@@ -2940,6 +3257,13 @@ static int __init init_lru_gen(void)
 {
 	BUILD_BUG_ON(MIN_NR_GENS + 1 >= MAX_NR_GENS);
 	BUILD_BUG_ON(BIT(LRU_GEN_WIDTH) <= MAX_NR_GENS);
+	BUILD_BUG_ON(sizeof(MM_STAT_CODES) != NR_MM_STATS + 1);
+
+	if (mem_cgroup_disabled()) {
+		global_mm_list = alloc_mm_list();
+		if (WARN_ON_ONCE(!global_mm_list))
+			return -ENOMEM;
+	}
 
 	if (hotplug_memory_notifier(lru_gen_online_mem, 0))
 		pr_err("lru_gen: failed to subscribe hotplug notifications\n");
