如何使用 JPA 实现分页查询并返回 VO 对象

JPA分页踩坑指南

1.原生sql查询返回的vo类如果包含主键id,结果无法自动映射,需要借助ResultTransformer做投影,所以我定义了一个投影工具类:

JpaCommonService

2.异步调用原生查询方法的时候,unwrap 的目标类型需要用接口 NativeQuery,不能用实现类 NativeQueryImpl:

正确:NativeQuery<?> query = entityManager.createNativeQuery(sql).unwrap(NativeQuery.class);
错误:NativeQueryImpl<?> query = entityManager.createNativeQuery(sql).unwrap(NativeQueryImpl.class);

3.用 JPQL(entityManager.createQuery())做分页查询时,如果分页涉及子查询,JPQL 是不支持的,只能改用原生sql。例如 JPQL 不支持这种写法:select count(1) from (select * from User) t


以下是分页返回vo的具体实现

分页返回vo

下面是用到的一些类和方法

分页参数

package com.example.springbootjpadruid.domain.common;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class BasePage {

    private int page = 1;

    private int size = 10;

    /**
     * 是否需要分页
     */
    private boolean needPage = false;

    public int getPage() {
        if (page < 1) {
            return 0;
        } else {
            return page - 1;
        }
    }

    public int getOffset() {
        return getPage() * size;
    }
}

分页返回结果

package com.example.springbootjpadruid.domain.common;

import cn.hutool.core.bean.BeanUtil;
import lombok.Data;
import org.springframework.data.domain.Page;

import java.util.List;

/**
 * Paging result returned to callers: 1-based page number, page size, total row count,
 * total page count and the rows of the current page.
 *
 * @param <T> element type of the returned rows
 */
@Data
public class PageVO<T> {

    /** Current page number, 1-based. */
    private int page = 1;

    /** Page size. */
    private int size = 10;

    /** Total number of rows across all pages. */
    private long total;

    /** Total number of pages. */
    private int totalPage;

    /** Rows of the current page. */
    private List<T> list;

    /**
     * Builds a {@code PageVO} from a Spring Data page, copying every element of the page
     * content into an instance of {@code clazz} via {@code BeanUtil.copyToList}.
     *
     * @param page  source page
     * @param clazz target element type
     * @param <S>   element type of the source page
     * @param <T>   element type of the result
     * @return the populated paging result
     */
    public static <S, T> PageVO<T> of(Page<S> page, Class<T> clazz) {
        List<T> converted = BeanUtil.copyToList(page.getContent(), clazz);
        return assemble(page, converted);
    }

    /**
     * Builds a {@code PageVO} from a Spring Data page, reusing the page content as-is.
     *
     * @param page source page
     * @param <T>  element type
     * @return the populated paging result
     */
    public static <T> PageVO<T> of(Page<T> page) {
        return assemble(page, page.getContent());
    }

    /**
     * Copies the paging metadata of {@code source} and attaches the given rows.
     * The page number is shifted by one because Spring Data's {@code getNumber()} is added
     * to 1 to obtain the 1-based value exposed here.
     */
    private static <T> PageVO<T> assemble(Page<?> source, List<T> rows) {
        PageVO<T> result = new PageVO<>();
        result.list = rows;
        result.totalPage = source.getTotalPages();
        result.total = source.getTotalElements();
        result.size = source.getSize();
        result.page = source.getNumber() + 1;
        return result;
    }
}

原生sql查询通用方法

package com.example.springbootjpadruid.service;

import com.alibaba.fastjson.JSON;
import com.example.springbootjpadruid.domain.common.BasePage;
import org.hibernate.query.NativeQuery;
import org.hibernate.transform.ResultTransformer;
import org.springframework.beans.SimpleTypeConverter;
import org.springframework.stereotype.Service;

import javax.persistence.EntityManager;
import javax.persistence.PersistenceContext;
import javax.persistence.Query;
import java.beans.Introspector;
import java.beans.PropertyDescriptor;
import java.lang.reflect.Method;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
 * Generic helper for native SQL queries. Result rows are mapped either to maps keyed by the
 * lower-cased column alias or to beans, where underscores in column aliases are dropped and the
 * match against bean property names is case-insensitive (so {@code user_role} and
 * {@code USER_ROLE} both populate {@code userRole}).
 *
 * @author Jonny
 */
@Service
public class JpaCommonService {

    /** Matches named parameters of the form {@code :name}; compiled once instead of per call. */
    private static final Pattern NAMED_PARAM_PATTERN = Pattern.compile(":(\\w+)");

    @PersistenceContext
    private EntityManager entityManager;

    /**
     * Runs a native query whose parameters come from a paging/query object.
     *
     * <p>The object is converted to a map through a JSON round trip; only entries whose key
     * appears as a {@code :name} placeholder in {@code sql} are bound as query parameters.
     *
     * @param sql    native SQL using {@code :name} placeholders
     * @param params query object; its properties supply both the placeholder values and the
     *               paging settings ({@code needPage}, {@code offset}, {@code size})
     * @param clazz  bean type for each row, or {@code null} to get one map per row
     * @param <P>    query object type
     * @return the mapped rows
     */
    public <P extends BasePage> List<?> queryList(String sql, P params, Class<?> clazz) {
        Map<String, Object> mapParams = toParamMap(params);
        Map<String, Object> filteredParams = filterParams(extractSqlParams(sql), mapParams);
        return queryList(sql, mapParams, filteredParams, clazz);
    }

    /**
     * Runs a native count query whose parameters come from a paging/query object.
     *
     * @param sql    native count SQL using {@code :name} placeholders
     * @param params query object supplying the placeholder values
     * @param <P>    query object type
     * @return the total number of rows reported by the query
     */
    public <P extends BasePage> long count(String sql, P params) {
        Map<String, Object> mapParams = toParamMap(params);
        Map<String, Object> filteredParams = filterParams(extractSqlParams(sql), mapParams);
        return count(sql, mapParams, filteredParams, Long.class);
    }

    /**
     * Runs a native query and maps every row.
     *
     * @param sql              native SQL using {@code :name} placeholders
     * @param params           full parameter map; only the paging keys ({@code needPage},
     *                         {@code offset}, {@code size}) are read from it, may be {@code null}
     * @param extractSqlParams parameters actually referenced by {@code sql}; each entry is bound
     *                         to the query, may be {@code null}
     * @param clazz            bean type for each row, or {@code null} to get one map per row
     * @return the mapped rows
     * @throws RuntimeException if building, executing or mapping the query fails
     */
    @SuppressWarnings({"unchecked", "rawtypes", "deprecation"})
    public List<?> queryList(String sql, Map<String, Object> params, Map<String, Object> extractSqlParams, Class<?> clazz) {
        try {
            NativeQuery<?> query = entityManager.createNativeQuery(sql).unwrap(NativeQuery.class);
            if (Objects.nonNull(extractSqlParams)) {
                extractSqlParams.forEach((k, v) -> query.setParameter(k, v));
            }
            needPage(params, query);
            // A fresh transformer per query keeps its (non thread-safe) converter confined
            // to the thread executing this query.
            ResultTransformer transformer = Objects.isNull(clazz)
                    ? new AliasToEntityMapResultTransformer()
                    : new AliasToBeanResultTransformer(clazz);
            query.setResultTransformer(transformer);
            return query.getResultList();
        } catch (Exception e) {
            throw new RuntimeException("查询或转换数据时出错", e);
        }
    }

    /**
     * Runs a native count query.
     *
     * @param countSql       native count SQL using {@code :name} placeholders
     * @param params         full parameter map; currently unused
     * @param filteredParams parameters actually referenced by {@code countSql}; each entry is
     *                       bound to the query, may be {@code null}
     * @param clazz          currently unused
     * @return the total number of rows reported by the query
     * @throws RuntimeException if executing the query or reading its result fails
     */
    public long count(String countSql, Map<String, Object> params, Map<String, Object> filteredParams, Class<?> clazz) {
        try {
            Query countQuery = entityManager.createNativeQuery(countSql);
            if (Objects.nonNull(filteredParams)) {
                filteredParams.forEach(countQuery::setParameter);
            }
            Object single = countQuery.getSingleResult();
            // Drivers return different numeric types for count(); read numbers directly and
            // fall back to parsing the textual form for anything else.
            if (single instanceof Number) {
                return ((Number) single).longValue();
            }
            return Long.parseLong(single.toString());
        } catch (Exception e) {
            throw new RuntimeException("查询记录数出错", e);
        }
    }

    /**
     * Converts a query object to a map through a JSON round trip.
     *
     * @param params query object, may be {@code null}
     * @return the parsed map
     */
    @SuppressWarnings("unchecked")
    private static Map<String, Object> toParamMap(Object params) {
        return JSON.parseObject(JSON.toJSONString(params), Map.class);
    }

    /**
     * Applies paging to the query when requested.
     *
     * <p>Paging is applied only when {@code params} contains {@code needPage = true}; in that
     * case {@code offset} and {@code size} must both be present as numbers.
     *
     * @param params query parameters, may be {@code null} (no paging)
     * @param query  query to configure
     * @throws IllegalArgumentException if paging is requested but {@code offset} or
     *                                  {@code size} is missing or not numeric
     */
    private void needPage(Map<String, Object> params, Query query) {
        if (params == null || !Boolean.TRUE.equals(params.get("needPage"))) {
            return;
        }
        query.setFirstResult(requiredInt(params, "offset"));
        query.setMaxResults(requiredInt(params, "size"));
    }

    /**
     * Reads a mandatory numeric paging value from the parameter map.
     */
    private static int requiredInt(Map<String, Object> params, String key) {
        Object value = params.get(key);
        if (!(value instanceof Number)) {
            throw new IllegalArgumentException("Paging parameter '" + key + "' must be a number but was: " + value);
        }
        return ((Number) value).intValue();
    }

    /**
     * Extracts the names of all {@code :name} placeholders found in the SQL text.
     *
     * @param sql SQL text
     * @return the set of placeholder names, without the leading colon
     */
    private static Set<String> extractSqlParams(String sql) {
        Matcher matcher = NAMED_PARAM_PATTERN.matcher(sql);
        Set<String> params = new HashSet<>();
        while (matcher.find()) {
            params.add(matcher.group(1));
        }
        return params;
    }

    /**
     * Keeps only the parameters that are referenced by the SQL.
     *
     * @param sqlParams placeholder names used in the SQL
     * @param params    all available parameters, may be {@code null}
     * @return the entries of {@code params} whose key is in {@code sqlParams}; empty when
     *         {@code params} is {@code null} or empty
     */
    private static Map<String, Object> filterParams(Set<String> sqlParams, Map<String, Object> params) {
        if (params == null || params.isEmpty()) {
            return Collections.emptyMap();
        }
        return params.entrySet().stream()
                .filter(entry -> sqlParams.contains(entry.getKey()))
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
    }

    /**
     * Maps each row to a {@code Map} keyed by the lower-cased column alias.
     */
    private static class AliasToEntityMapResultTransformer implements ResultTransformer {

        @Override
        public Object transformTuple(Object[] tuple, String[] aliases) {
            Map<String, Object> result = new HashMap<>();
            for (int i = 0; i < aliases.length; i++) {
                if (aliases[i] != null) {
                    // Locale.ROOT keeps the key stable regardless of the JVM default locale.
                    result.put(aliases[i].toLowerCase(Locale.ROOT), tuple[i]);
                }
            }
            return result;
        }

        @Override
        public List transformList(List collection) {
            return collection;
        }
    }

    /**
     * Maps each row to a new instance of the target bean class. A column alias is matched to a
     * bean property by lower-casing it and removing underscores, then comparing against the
     * lower-cased property name.
     */
    private static class AliasToBeanResultTransformer implements ResultTransformer {

        private final Class<?> resultClass;

        /** Lower-cased property name to setter; resolved once instead of once per row. */
        private final Map<String, Method> writeMethodMap;

        /**
         * One converter per transformer instance: Spring documents SimpleTypeConverter as
         * not thread-safe, so it must not be shared across concurrently running queries.
         */
        private final SimpleTypeConverter converter = new SimpleTypeConverter();

        /**
         * @param resultClass bean type to instantiate for every row; needs a no-arg constructor
         * @throws RuntimeException if the bean's properties cannot be introspected
         */
        public AliasToBeanResultTransformer(Class<?> resultClass) {
            this.resultClass = resultClass;
            this.writeMethodMap = resolveWriteMethods(resultClass);
        }

        /**
         * Collects the setters of all writable bean properties, keyed by lower-cased name.
         */
        private static Map<String, Method> resolveWriteMethods(Class<?> beanClass) {
            try {
                PropertyDescriptor[] props = Introspector.getBeanInfo(beanClass).getPropertyDescriptors();
                return Arrays.stream(props)
                        .filter(p -> Objects.nonNull(p.getWriteMethod()))
                        .collect(Collectors.toMap(p -> p.getName().toLowerCase(Locale.ROOT), PropertyDescriptor::getWriteMethod));
            } catch (Exception e) {
                throw new RuntimeException("实体映射错误: " + beanClass.getName(), e);
            }
        }

        /**
         * Converts one result row to a bean, mapping underscore-separated column aliases to
         * camelCase properties.
         *
         * @param tuple   column values of the row
         * @param aliases column aliases, parallel to {@code tuple}
         * @return the populated bean
         * @throws RuntimeException if instantiation, conversion or a setter call fails
         */
        @Override
        public Object transformTuple(Object[] tuple, String[] aliases) {
            try {
                Object result = resultClass.getDeclaredConstructor().newInstance();
                for (int i = 0; i < aliases.length; i++) {
                    if (aliases[i] == null) {
                        continue;
                    }
                    String fieldName = aliases[i].toLowerCase(Locale.ROOT).replace("_", "");
                    Method writeMethod = writeMethodMap.get(fieldName);
                    // Columns without a matching writable property are ignored.
                    if (writeMethod != null) {
                        Object value = converter.convertIfNecessary(tuple[i], writeMethod.getParameterTypes()[0]);
                        writeMethod.invoke(result, value);
                    }
                }
                return result;
            } catch (Exception e) {
                throw new RuntimeException("实体映射错误: " + resultClass.getName(), e);
            }
        }

        @Override
        public List transformList(List collection) {
            return collection;
        }
    }
}

使用方法

package com.example.springbootjpadruid.service.impl;

import com.example.springbootjpadruid.domain.common.PageVO;
import com.example.springbootjpadruid.domain.entity.primary.User;
import com.example.springbootjpadruid.domain.query.UserQuery;
import com.example.springbootjpadruid.domain.vo.UserVO;
import com.example.springbootjpadruid.repository.primary.UserRepository;
import com.example.springbootjpadruid.service.JpaCommonService;
import com.example.springbootjpadruid.service.UserService;
import com.github.wenhao.jpa.PredicateBuilder;
import com.github.wenhao.jpa.Specifications;
import org.hibernate.query.NativeQuery;
import org.hibernate.transform.ResultTransformer;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageImpl;
import org.springframework.data.domain.PageRequest;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;

import javax.annotation.Resource;
import javax.persistence.EntityManager;
import javax.persistence.PersistenceContext;
import javax.persistence.Query;
import java.util.List;

/**
 * User service demonstrating three ways to return a paged {@code PageVO<UserVO>}:
 * repository paging, a hand-mapped native query, and the generic {@code JpaCommonService}.
 */
@Service
public class UserServiceImpl implements UserService {

    @Resource
    private UserRepository repository;

    @PersistenceContext
    private EntityManager entityManager;

    @Resource
    private JpaCommonService jpaCommonService;

    /**
     * Persists the given users.
     *
     * @param users users to save
     * @return the literal {@code "ok"}
     */
    @Override
    public String init(List<User> users) {
        repository.saveAll(users);
        return "ok";
    }

    /**
     * Looks up the user with username {@code zhangsan} and age {@code 20}.
     *
     * @return the matching user, or {@code null} if none is found
     */
    @Override
    public User getUser() {
        PredicateBuilder<User> where = Specifications.and();
        where.eq("username", "zhangsan");
        where.eq("age", 20);
        return repository.findOne(where.build()).orElse(null);
    }

    /**
     * Paging variant 1: repository paging, with entities copied to {@code UserVO}.
     *
     * @param param query parameters carrying page and size
     * @return one page of users
     */
    @Override
    public PageVO<UserVO> listPage(UserQuery param) {
        PageRequest pageable = PageRequest.of(param.getPage(), param.getSize());
        Page<User> page = repository.findAll(pageable);
        return PageVO.of(page, UserVO.class);
    }

    /**
     * Paging variant 2: native SQL with a hand-written row mapper.
     *
     * @param param query parameters
     * @return one page of users
     */
    @SuppressWarnings({"unchecked", "rawtypes", "deprecation"})
    @Override
    public PageVO<UserVO> nativeListPage(UserQuery param) {
        // StringBuilder: the buffers are method-local, so StringBuffer's locking is not needed.
        StringBuilder sql = new StringBuilder("select id as id, username as username, age as age, address as address from sys_user where 1 = 1");
        buildWhere(param, sql);

        StringBuilder countSql = new StringBuilder("select count(1) from sys_user where 1 = 1");
        buildWhere(param, countSql);

        PageRequest pageable = PageRequest.of(param.getPage(), param.getSize());

        NativeQuery query = entityManager.createNativeQuery(sql.toString()).unwrap(NativeQuery.class);
        query.setFirstResult(param.getOffset());
        query.setMaxResults(param.getSize());

        // The result contains the primary key id, which
        // Transformers.aliasToBean(UserVO.class) could not map automatically,
        // so the row is mapped by hand through a custom transformer.
        query.setResultTransformer(new ResultTransformer() {
            @Override
            public Object transformTuple(Object[] objects, String[] strings) {
                UserVO userVO = new UserVO();
                // Column order follows the select list: id, username, age, address.
                // A NULL column leaves the corresponding property untouched instead of
                // failing with a NullPointerException on toString().
                if (objects[0] != null) {
                    userVO.setId(Long.parseLong(objects[0].toString()));
                }
                if (objects[1] != null) {
                    userVO.setUsername(objects[1].toString());
                }
                if (objects[2] != null) {
                    userVO.setAge(Integer.parseInt(objects[2].toString()));
                }
                if (objects[3] != null) {
                    userVO.setAddress(objects[3].toString());
                }
                return userVO;
            }

            @Override
            public List transformList(List list) {
                return list;
            }
        });
        buildParam(param, query);

        Query countQuery = entityManager.createNativeQuery(countSql.toString());
        buildParam(param, countQuery);
        long total = Long.parseLong(countQuery.getSingleResult().toString());

        PageImpl<UserVO> pageImpl = new PageImpl<>(query.getResultList(), pageable, total);
        return PageVO.of(pageImpl);
    }

    /**
     * Paging variant 3: native SQL mapped through {@code JpaCommonService}.
     *
     * <p>Column names are converted to camelCase automatically: underscores are handled
     * regardless of letter case, and fully upper-case names are lower-cased first. This
     * addresses the limitation of variant 2, where native SQL without {@code as} aliases
     * could not be mapped to camelCase properties automatically.
     *
     * <p>Note that this method sets {@code needPage} to {@code true} on the passed-in
     * {@code param}.
     *
     * @param param query parameters
     * @return one page of users
     */
    @SuppressWarnings("unchecked")
    @Override
    public PageVO<UserVO> listPageByJpaUtil(UserQuery param) {
        // Alternative with an explicit column list:
        // select id, username, age, address, user_role as USER_ROLE from sys_user where 1 = 1
        StringBuilder sql = new StringBuilder("select * from sys_user where 1 = 1");
        buildWhere(param, sql);

        StringBuilder countSql = new StringBuilder("select count(1) from sys_user where 1 = 1");
        buildWhere(param, countSql);

        PageRequest pageable = PageRequest.of(param.getPage(), param.getSize());
        param.setNeedPage(Boolean.TRUE);
        List<UserVO> userList = (List<UserVO>) jpaCommonService.queryList(sql.toString(), param, UserVO.class);
        long total = jpaCommonService.count(countSql.toString(), param);

        PageImpl<UserVO> pageImpl = new PageImpl<>(userList, pageable, total);
        return PageVO.of(pageImpl);
    }

    /**
     * Binds the {@code username} parameter when a username filter is present.
     * Must stay in sync with {@link #buildWhere(UserQuery, StringBuilder)}, which adds the
     * matching placeholder under the same condition.
     */
    private void buildParam(UserQuery param, Query query) {
        if (StringUtils.hasText(param.getUsername())) {
            query.setParameter("username", param.getUsername());
        }
    }

    /**
     * Appends the optional {@code username} condition to the SQL being built.
     */
    private void buildWhere(UserQuery param, StringBuilder sql) {
        if (StringUtils.hasText(param.getUsername())) {
            sql.append(" and username = :username");
        }
    }

}
1 个赞