// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.intellij.spring.contexts.model;

import com.intellij.jam.JamService;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.module.Module;
import com.intellij.openapi.util.NotNullLazyValue;
import com.intellij.psi.PsiClass;
import com.intellij.psi.PsiPackage;
import com.intellij.psi.search.GlobalSearchScope;
import com.intellij.psi.search.PackageScope;
import com.intellij.psi.util.PsiModificationTracker;
import com.intellij.spring.model.BeanService;
import com.intellij.spring.model.CommonSpringBean;
import com.intellij.spring.model.SpringBeanPointer;
import com.intellij.spring.model.SpringQualifier;
import com.intellij.spring.model.jam.JamPsiMemberSpringBean;
import com.intellij.spring.model.jam.SpringJamModel;
import com.intellij.spring.model.jam.javaConfig.ImplicitlyRegisteredBeansProvider;
import com.intellij.spring.model.jam.javaConfig.SpringJavaBean;
import com.intellij.spring.model.jam.stereotype.SpringConfiguration;
import com.intellij.spring.model.jam.stereotype.SpringStereotypeElement;
import com.intellij.spring.model.jam.utils.SpringJamUtils;
import com.intellij.spring.model.jam.utils.filters.SpringContextFilter;
import com.intellij.spring.model.utils.SpringProfileUtils;
import com.intellij.spring.model.xml.context.SpringBeansPackagesScan;
import com.intellij.util.NotNullFunction;
import com.intellij.util.SmartList;
import com.intellij.util.containers.ConcurrentFactoryMap;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.annotations.TestOnly;

import java.util.*;

/**
 * Spring model whose local beans are the stereotype components found by component-scanning a lazily
 * computed set of packages.
 * <p>
 * The scan is restricted to the module's scope with dependencies and libraries. Besides the scanned
 * components themselves, {@link #getLocalBeans()} also contains the Java-config beans declared by those
 * components, plus anything those beans implicitly register.
 */
public class ComponentScanPackagesModel extends CacheableCommonSpringModel {

  /**
   * When {@code false} (the default), scanning the {@code <default>} package in unit-test mode throws.
   * Changed only via {@link #setAllowDefaultPackageForTests(boolean)}.
   */
  private static boolean ourAllowDefaultPackageForTests = false;

  /** Packages to scan; evaluated when the scanned beans are first calculated. */
  private final NotNullLazyValue<? extends Set<PsiPackage>> myPackages;

  /** Cached result of {@link #calculateLocalBeans()}; {@code null} until first requested. */
  private volatile Collection<SpringBeanPointer<?>> myScannedBeans;

  /** Per-qualifier lookup cache; each entry is computed on first access. */
  private final Map<SpringQualifier, List<SpringBeanPointer<?>>> myLocalBeansByQualifier;

  @NotNull
  private final Module myModule;

  /**
   * @param packages lazily evaluated set of packages to component-scan
   * @param module   module whose scope (with dependencies and libraries) bounds the scan
   */
  public ComponentScanPackagesModel(@NotNull NotNullLazyValue<? extends Set<PsiPackage>> packages, @NotNull Module module) {
    myPackages = packages;
    myModule = module;
    // The factory is invoked lazily on map access, not during construction.
    myLocalBeansByQualifier = ConcurrentFactoryMap.createMap(key -> findLocalBeansByQualifier(this, key));
  }

  /**
   * Returns the scanned components together with the beans they declare, computing them on first call.
   * <p>
   * No lock is taken. Concurrent first callers may each run {@link #calculateLocalBeans()}, and the last
   * write to the cache wins. The volatile field is read exactly once per call.
   */
  @Override
  public Collection<SpringBeanPointer<?>> getLocalBeans() {
    Collection<SpringBeanPointer<?>> beans = myScannedBeans;
    if (beans == null) {
      beans = calculateLocalBeans();
      myScannedBeans = beans;
    }
    return beans;
  }

  /**
   * Computes the local beans: every pointer from {@link #calculateScannedBeans()}, followed by the
   * {@link SpringJavaBean}s declared by scanned stereotype components and the beans those Java beans
   * implicitly register.
   *
   * @return an insertion-ordered, de-duplicated collection of bean pointers
   */
  protected final Collection<SpringBeanPointer<?>> calculateLocalBeans() {
    Collection<SpringBeanPointer<?>> pointers = calculateScannedBeans();

    Set<CommonSpringBean> javaBeans = new LinkedHashSet<>();
    for (SpringBeanPointer<?> pointer : pointers) {
      CommonSpringBean springBean = pointer.getSpringBean();
      if (springBean instanceof SpringStereotypeElement) {
        SpringStereotypeElement stereotypeElement = (SpringStereotypeElement)springBean;
        PsiClass psiClass = stereotypeElement.getPsiElement();
        // Classes that also carry a JamPsiMemberSpringBean element do not contribute their declared beans here
        // (the class's own pointer is still kept below).
        // NOTE(review): presumably those beans are contributed by another model — confirm.
        if (JamService.getJamService(psiClass.getProject())
              .getJamElement(JamPsiMemberSpringBean.PSI_MEMBER_SPRING_BEAN_JAM_KEY, psiClass) != null) {
          continue;
        }
        for (SpringJavaBean springJavaBean : stereotypeElement.getBeans()) {
          javaBeans.add(springJavaBean);
          if (springJavaBean instanceof ImplicitlyRegisteredBeansProvider) {
            javaBeans.addAll(((ImplicitlyRegisteredBeansProvider)springJavaBean).getImplicitlyRegistered());
          }
        }
      }
    }

    Set<SpringBeanPointer<?>> beans = new LinkedHashSet<>();
    beans.addAll(pointers);
    beans.addAll(BeanService.getInstance().mapSpringBeans(javaBeans));

    return beans;
  }

  /**
   * Scans this model's packages in its module, filtered by the active profiles, using default filters and
   * no include/exclude filters.
   */
  protected Collection<SpringBeanPointer<?>> calculateScannedBeans() {
    return getScannedComponents(myPackages.getValue(), getModule(), getActiveProfiles());
  }

  /**
   * Component-scans {@code packages} with default filters and no explicit include/exclude filters.
   *
   * @param packages packages to scan (recursively)
   * @param module   module whose scope bounds the scan
   * @param profiles active profiles used to filter the result
   * @return pointers to the matching stereotype components
   */
  public static Collection<SpringBeanPointer<?>> getScannedComponents(@NotNull Set<PsiPackage> packages,
                                                                      @NotNull Module module,
                                                                      @Nullable Set<String> profiles) {
    return getScannedComponents(packages, module, profiles, true,
                                Collections.emptySet(),
                                Collections.emptySet());
  }

  /**
   * Returns the {@link SpringConfiguration}s found by the given packages-scan element. The scan uses that
   * element's packages, default-filter setting and include/exclude filters.
   *
   * @param scan     scan definition supplying packages and filters
   * @param module   module whose scope bounds the scan
   * @param profiles active profiles used to filter the result
   * @return insertion-ordered set of matching configurations
   */
  public static Collection<SpringConfiguration> getScannedConfigurations(SpringBeansPackagesScan scan,
                                                                         @NotNull final Module module,
                                                                         @Nullable Set<String> profiles) {

    List<CommonSpringBean> components = getScannedComponents(dom -> SpringJamModel.getModel(module).getConfigurations(dom),
                                                             scan.getPsiPackages(), module, profiles, scan.useDefaultFilters(),
                                                             scan.getExcludeContextFilters(),
                                                             scan.getIncludeContextFilters());

    Set<SpringConfiguration> configurations = new LinkedHashSet<>();
    for (CommonSpringBean component : components) {
      if (component instanceof SpringConfiguration) {
        configurations.add((SpringConfiguration)component);
      }
    }
    return configurations;
  }

  /**
   * Component-scans {@code packages} for stereotype components, applying the given filters and profiles.
   *
   * @return pointers to the matching stereotype components
   */
  public static Collection<SpringBeanPointer<?>> getScannedComponents(@NotNull Set<PsiPackage> packages,
                                                                      @NotNull final Module module,
                                                                      @Nullable Set<String> profiles,
                                                                      boolean useDefaultFilters,
                                                                      @NotNull Set<SpringContextFilter.Exclude> excludeContextFilters,
                                                                      @NotNull Set<SpringContextFilter.Include> includeContextFilters) {
    Collection<CommonSpringBean> components =
      getScannedComponents(scope -> SpringJamModel.getModel(module).getStereotypeComponents(scope),
                           packages, module, profiles,
                           useDefaultFilters, excludeContextFilters, includeContextFilters);

    return BeanService.getInstance().mapSpringBeans(components);
  }

  /**
   * Core scan routine. Searches the intersection of the module's scope (with dependencies and libraries)
   * and the recursive scopes of all {@code packages}, then applies component-scan filters and profile
   * filtering.
   *
   * @param components            produces candidate elements for a given search scope
   * @param packages              packages to scan (including subpackages)
   * @param module                module whose scope bounds the scan
   * @param profiles              active profiles used to filter the result
   * @param useDefaultFilters     whether default component-scan filters apply
   * @param excludeContextFilters exclude filters to apply
   * @param includeContextFilters include filters to apply
   * @return the filtered beans. This is an empty list if {@code module} is disposed or {@code packages} is empty.
   * @throws IllegalArgumentException in unit-test mode, if {@code packages} contains the {@code <default>}
   *                                  package and {@link #setAllowDefaultPackageForTests(boolean)} was not enabled
   */
  public static <T extends SpringStereotypeElement> List<CommonSpringBean> getScannedComponents(@NotNull NotNullFunction<? super GlobalSearchScope, ? extends List<T>> components,
                                                                                                @NotNull Set<PsiPackage> packages,
                                                                                                @NotNull Module module,
                                                                                                @Nullable Set<String> profiles,
                                                                                                boolean useDefaultFilters,
                                                                                                @NotNull Set<SpringContextFilter.Exclude> excludeContextFilters,
                                                                                                @NotNull Set<SpringContextFilter.Include> includeContextFilters) {
    if (module.isDisposed() || packages.isEmpty()) return new SmartList<>();

    // Validate and build per-package scopes in one explicit pass (no side effects hidden in a stream peek()).
    List<GlobalSearchScope> packageScopes = new ArrayList<>(packages.size());
    for (PsiPackage psiPackage : packages) {
      checkDefaultPackageAllowed(psiPackage);
      packageScopes.add(PackageScope.packageScope(psiPackage, true));
    }

    GlobalSearchScope allPackagesUnionScope = GlobalSearchScope.union(packageScopes.toArray(new GlobalSearchScope[0]));

    final GlobalSearchScope moduleScope = GlobalSearchScope.moduleWithDependenciesAndLibrariesScope(module);
    final GlobalSearchScope effectiveSearchScope = moduleScope.intersectWith(allPackagesUnionScope);

    final List<T> allPointers = components.fun(effectiveSearchScope);

    Set<CommonSpringBean> filteredBeans =
      SpringJamUtils.getInstance().filterComponentScannedStereotypes(module, allPointers, packages, useDefaultFilters,
                                                                     excludeContextFilters, includeContextFilters);

    return SpringProfileUtils.filterBeansInActiveProfiles(filteredBeans, profiles);
  }

  /**
   * Rejects the {@code <default>} (empty-name) package in unit-test mode unless tests explicitly allowed it.
   * This avoids performance problems: the scan includes subpackages, so the default package would cover
   * everything in scope.
   */
  private static void checkDefaultPackageAllowed(@NotNull PsiPackage psiPackage) {
    if (psiPackage.getQualifiedName().isEmpty() &&
        ApplicationManager.getApplication().isUnitTestMode() &&
        !ourAllowDefaultPackageForTests) {
      throw new IllegalArgumentException("Do not use component-scan with <default> package in tests");
    }
  }

  /**
   * Override assertion to use default package in tests.
   * This is disabled by default to avoid performance problems.
   */
  @TestOnly
  public static void setAllowDefaultPackageForTests(boolean value) {
    ourAllowDefaultPackageForTests = value;
  }

  @NotNull
  @Override
  public Module getModule() {
    return myModule;
  }

  /** Cached processors are invalidated on any PSI modification. */
  @Override
  protected Collection<Object> getCachingProcessorsDependencies() {
    return Collections.singleton(PsiModificationTracker.MODIFICATION_COUNT);
  }

  /** Returns the local beans matching {@code qualifier}, memoized per qualifier. */
  @NotNull
  @Override
  public List<SpringBeanPointer<?>> findQualified(@NotNull SpringQualifier qualifier) {
    return myLocalBeansByQualifier.get(qualifier);
  }
}
