/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

/**
 * This file has classes that combine New Tab feature events (aggregated from a SQLite table) into an interest model.
 */

import {
  FORMAT,
  AggregateResultKeys,
  SPECIAL_FEATURE_CLICK,
} from "resource://newtab/lib/InferredModel/InferredConstants.sys.mjs";

// Number of milliseconds in one day (24 h * 60 min * 60 s * 1000 ms).
export const DAYS_TO_MS = 24 * 60 * 60 * 1000;

// 2^32, i.e. the count of distinct Uint32 values (not the largest 32-bit
// integer, despite the name). Dividing a random Uint32 by it yields [0, 1).
const MAX_INT_32 = 0x100000000;

/**
 * Element-wise division of two dictionaries.
 *
 * Each key of `numerator` maps to `numerator[k] / denominator[k]`, or to 0 when
 * the denominator entry is missing or falsy (e.g. 0). Keys present only in
 * `denominator` are added with a value of 0. Neither input is modified.
 *
 * @param {{[key: string]: number}} numerator
 * @param {{[key: string]: number}} denominator
 * @returns {{[key: string]: number}} A new dictionary of quotients.
 */
export function divideDict(numerator, denominator) {
  const quotients = {};
  for (const key of Object.keys(numerator)) {
    const divisor = denominator[key];
    quotients[key] = divisor ? numerator[key] / divisor : 0;
  }
  // Zero-fill keys that only the denominator knows about.
  for (const key of Object.keys(denominator)) {
    if (!(key in quotients)) {
      quotients[key] = 0.0;
    }
  }
  return quotients;
}

/**
 * Unary encoding with randomized response, for local differential privacy.
 *
 * `x` is one-hot encoded into `N` bits and every bit is then randomized
 * independently: the bit at position `x` is reported as "1" with probability
 * `p`, and every other bit is reported as "1" with probability `q`. The output
 * must be decoded back to an integer in aggregate when building a histogram on
 * a server.
 *
 * @param {number} x - Integer input (0 <= x < N)
 * @param {number} N - Number of possible values, i.e. the bitstring length
 * @param {number} p - Probability of keeping the true (one-hot) bit as 1
 * @param {number} q - Probability of flipping each other 0-bit to 1
 * @returns {string} - Bitstring of length N after unary encoding and randomized response
 */
export function unaryEncodeDiffPrivacy(x, N, p, q) {
  const bitstring = [];
  const randomValues = new Uint32Array(N);
  crypto.getRandomValues(randomValues);
  for (let i = 0; i < N; i++) {
    // Uniform sample in [0, 1) with 2^-32 resolution.
    const rand = randomValues[i] / MAX_INT_32;
    const probOfOne = i === x ? p : q;
    // Strict `<` so that P(bit = 1) matches the requested probability: a
    // probability of 0 never sets a bit and a probability of 1 always does
    // (`<=` would set a bit with probability 0 whenever the sample is exactly 0).
    bitstring.push(rand < probOfOne ? "1" : "0");
  }
  return bitstring.join("");
}

/**
 * Accumulates `value` into `dict[key]` in place. When the key is not yet
 * present, the entry is created with `value` as its initial value.
 *
 * @param {object} dict - The dictionary to modify.
 * @param {string} key - The key whose value should be added or set.
 * @param {number} value - The value to add to the key.
 */
export function dictAdd(dict, key, value) {
  dict[key] = key in dict ? dict[key] + value : value;
}

/**
 * Maps every value of an object through `fn`, keeping the same keys.
 *
 * @param {object} obj - The object whose values should be transformed (not modified).
 * @param {Function} fn - The function to apply to each value.
 * @returns {object} A new object with the same own enumerable string keys and transformed values.
 */
export function dictApply(obj, fn) {
  const transformedPairs = [];
  for (const [key, value] of Object.entries(obj)) {
    transformedPairs.push([key, fn(value)]);
  }
  return Object.fromEntries(transformedPairs);
}

/**
 * Class for re-scaling events based on how long ago they happened.
 */
export class DayTimeWeighting {
  /**
   * Instantiate class based on a series of day periods in the past.
   *
   * @param {number[]} pastDays Lengths, in days, of consecutive intervals going back in time,
   * most recent first. Intervals are cumulative: if the first value is 1 and the second is 5,
   * the first interval covers 0-1 days ago and the second covers 1-6 days ago.
   * @param {number[]} relativeWeight Relative weight of each interval. Must be the same length as pastDays.
   */
  constructor(pastDays, relativeWeight) {
    this.pastDays = pastDays;
    this.relativeWeight = relativeWeight;
  }

  /**
   * Build an instance from its JSON configuration.
   *
   * @param {{days: number[], relative_weight: number[]}} json
   * @returns {DayTimeWeighting}
   */
  static fromJSON(json) {
    return new DayTimeWeighting(json.days, json.relative_weight);
  }

  /**
   * Get a series of date intervals in the past, one per entry of pastDays.
   *
   * @param {number} curTimeMs Base time in milliseconds. Usually the current time.
   * @returns {{start: Date, end: Date}[]} Intervals from most recent to oldest; each
   * interval's `start` is the following interval's `end`.
   */
  getDateIntervals(curTimeMs) {
    let curEndTime = curTimeMs;
    return this.pastDays.map(daysAgo => {
      const startTime = curEndTime - daysAgo * DAYS_TO_MS;
      const interval = {
        start: new Date(startTime),
        end: new Date(curEndTime),
      };
      // Keep the cursor numeric instead of reusing the Date object, so the
      // subtraction above never depends on implicit Date-to-number coercion.
      curEndTime = startTime;
      return interval;
    });
  }

  /**
   * Get the relative weight of the interval at the given index.
   *
   * @param {number} weightIndex Index of the interval (as ordered by getDateIntervals).
   * @returns {number} Weight at index, or 0 if the index is out of range or has no weight configured.
   */
  getRelativeWeight(weightIndex) {
    // Guard both ends: a negative index or a relativeWeight array shorter than
    // pastDays would otherwise yield undefined and turn downstream sums into NaN.
    if (weightIndex < 0 || weightIndex >= this.pastDays.length) {
      return 0;
    }
    return this.relativeWeight[weightIndex] ?? 0;
  }
}

/**
 * Describes how a set of aggregated events maps to a single interest feature,
 * together with the optional quantization thresholds and randomized-response
 * parameters used when reporting that feature.
 */
export class InterestFeatures {
  /**
   * @param {string} name Name of the interest feature.
   * @param {{[feature: string]: number}} featureWeights Weight applied to each aggregated feature.
   * @param {number[]|null} [thresholds=null] Quantization boundaries, in ascending order.
   * @param {number} [diff_p=0.5] Probability of keeping the true bit as 1 in randomized response.
   * @param {number} [diff_q=0.5] Probability of flipping each other bit to 1 in randomized response.
   */
  constructor(
    name,
    featureWeights,
    thresholds = null,
    diff_p = 0.5,
    diff_q = 0.5
  ) {
    // What this interest is called and how raw features feed into it.
    this.name = name;
    this.featureWeights = featureWeights;
    // Quantization boundaries; expected to be in ascending order.
    this.thresholds = thresholds;
    // Randomized-response parameters consumed by applyDifferentialPrivacy.
    this.diff_p = diff_p;
    this.diff_q = diff_q;
  }

  /**
   * Build an instance from its JSON configuration. A missing `diff_p` or
   * `diff_q` falls back to the constructor default.
   *
   * @param {string} name Name of the interest feature.
   * @param {object} json Feature configuration (`features`, optional `thresholds`, `diff_p`, `diff_q`).
   * @returns {InterestFeatures}
   */
  static fromJSON(name, json) {
    const { features, thresholds, diff_p, diff_q } = json;
    return new InterestFeatures(
      name,
      features,
      thresholds || null,
      diff_p,
      diff_q
    );
  }

  /**
   * Quantize a feature value using the thresholds of this feature.
   *
   * @param {number} inValue Value computed by the model for the feature.
   * @returns {number} Index of the first threshold strictly greater than inValue, or
   * thresholds.length when there is none (so 0..thresholds.length inclusive). When no
   * thresholds are configured, inValue is returned unchanged.
   */
  applyThresholds(inValue) {
    if (!this.thresholds) {
      return inValue;
    }
    const bucket = this.thresholds.findIndex(limit => inValue < limit);
    return bucket === -1 ? this.thresholds.length : bucket;
  }

  /**
   * Apply unary-encoding randomized response to a quantized value, yielding a
   * randomized one-hot bitstring of length thresholds.length + 1. Histograms of
   * the true values can be estimated with reasonable accuracy from many reports.
   * When thresholds are missing or diff_p is missing/0, the value is returned as-is.
   *
   * @param {number} inValue Quantized value to randomize.
   * @returns {string|number} Randomized bitstring, or inValue unchanged.
   */
  applyDifferentialPrivacy(inValue) {
    const isEnabled = Boolean(this.thresholds) && Boolean(this.diff_p);
    if (!isEnabled) {
      return inValue;
    }
    const bucketCount = this.thresholds.length + 1;
    return unaryEncodeDiffPrivacy(
      inValue,
      bucketCount,
      this.diff_p,
      this.diff_q
    );
  }
}

/**
 * Manages the relative importance (relative CTR) of each tile format.
 */
export class TileImportance {
  /**
   * @param {{[formatKey: string]: number}} tileImportanceMappings Relative CTR keyed by FORMAT key name.
   */
  constructor(tileImportanceMappings) {
    // Re-key the configuration from FORMAT key names to FORMAT ids; formats
    // that are not configured are simply left out.
    this.mappings = {};
    Object.entries(FORMAT)
      .filter(([formatKey]) => formatKey in tileImportanceMappings)
      .forEach(([formatKey, formatId]) => {
        this.mappings[formatId] = tileImportanceMappings[formatKey];
      });
  }

  /**
   * @param {*} tileType A FORMAT id.
   * @returns {number} The configured relative CTR, or 1 when none is configured (or it is falsy).
   */
  getRelativeCTRForTile(tileType) {
    const relativeCTR = this.mappings[tileType];
    return relativeCTR ? relativeCTR : 1;
  }

  /**
   * @param {object} json Mapping of FORMAT key names to relative CTR.
   * @returns {TileImportance}
   */
  static fromJSON(json) {
    return new TileImportance(json);
  }
}

/**
 * A simple linear model for aggregating raw event features into interest features.
 */
export class FeatureModel {
  /**
   * @param {object} params
   * @param {string} [params.modelId] Model identifier.
   * @param {DayTimeWeighting} params.dayTimeWeighting Time weighting applied per date interval.
   * @param {{[name: string]: InterestFeatures}} params.interestVectorModel Interest features keyed by name.
   * @param {TileImportance} params.tileImportance Relative tile importance.
   * @param {string} [params.modelType] Model type from the configuration.
   * @param {boolean} [params.rescale=false] Whether to rescale by the maximum value.
   * @param {boolean} [params.logScale=false] Whether to apply natural log (ln(x + 1)) before rescaling.
   * @param {boolean} [params.normalize=false] Whether to L2-normalize the vector.
   * @param {boolean} [params.normalizeL1=false] Whether to L1-normalize the vector (divide by the sum).
   * @param {boolean} [params.clickScale=false] Whether computeInterestVector divides interest values
   *   by the click count when clicks are tracked as a feature.
   * @param {string[]|null} [params.privateFeatures=null] Features kept in the coarse private vector of
   *   computeCTRInterestVectors; null keeps every feature.
   */
  constructor({
    modelId,
    dayTimeWeighting,
    interestVectorModel,
    tileImportance,
    modelType,
    rescale = false,
    logScale = false,
    normalize = false,
    normalizeL1 = false,
    clickScale = false,
    privateFeatures = null,
  }) {
    this.modelId = modelId;
    this.tileImportance = tileImportance;
    this.dayTimeWeighting = dayTimeWeighting;
    this.interestVectorModel = interestVectorModel;
    this.rescale = rescale;
    this.logScale = logScale;
    this.normalize = normalize;
    this.normalizeL1 = normalizeL1;
    // Must be stored: computeInterestVector reads this.clickScale.
    this.clickScale = clickScale;
    this.modelType = modelType;
    // null means "do not filter"; this matches fromJSON's fallback when
    // private_features is absent. (An empty array would filter out everything.)
    this.privateFeatures = privateFeatures;
  }

  /**
   * Build a model from its JSON configuration.
   *
   * @param {object} json Model configuration.
   * @returns {FeatureModel}
   */
  static fromJSON(json) {
    const dayTimeWeighting = DayTimeWeighting.fromJSON(json.day_time_weighting);
    const interestVectorModel = {};
    const tileImportance = TileImportance.fromJSON(json.tile_importance || {});

    for (const [name, featureJson] of Object.entries(json.interest_vector)) {
      interestVectorModel[name] = InterestFeatures.fromJSON(name, featureJson);
    }

    return new FeatureModel({
      dayTimeWeighting,
      tileImportance,
      interestVectorModel,
      normalize: json.normalize,
      normalizeL1: json.normalize_l1,
      rescale: json.rescale,
      logScale: json.log_scale,
      // NOTE(review): every other key is snake_case; confirm the config really uses `clickScale`.
      clickScale: json.clickScale,
      modelType: json.model_type,
      privateFeatures: json.private_features ?? null,
    });
  }

  /**
   * @returns {boolean} True when every interest feature has a non-empty thresholds list.
   */
  supportsCoarseInterests() {
    return Object.values(this.interestVectorModel).every(
      fm => fm.thresholds && fm.thresholds.length
    );
  }

  /**
   * @returns {boolean} True when every interest feature has non-empty thresholds and
   *   diff_p/diff_q properties. (The InterestFeatures constructor always assigns
   *   diff_p/diff_q, so for those instances this reduces to the thresholds check.)
   */
  supportsCoarsePrivateInterests() {
    return Object.values(this.interestVectorModel).every(
      fm =>
        fm.thresholds &&
        fm.thresholds.length &&
        "diff_p" in fm &&
        "diff_q" in fm
    );
  }

  /**
   * Return date intervals for the query
   *
   * @param {number} curTimeMs Base time in milliseconds.
   * @returns {{start: Date, end: Date}[]}
   */
  getDateIntervals(curTimeMs) {
    return this.dayTimeWeighting.getDateIntervals(curTimeMs);
  }

  /**
   * Computes an interest vector or aggregate based on the model and raw SQL input.
   *
   * @param {object} config
   * @param {Array.<Array.<Array.<string|number>>>} config.dataForIntervals Raw aggregate output from the
   *   SQL query, one array of rows per date interval. Could be clicks or impressions.
   * @param {{[key: string]: number}} config.indexSchema Map of keys to indices in each row of dataForIntervals
   * @param {boolean} [config.applyPostProcessing=false] Whether to apply applyPostProcessing
   * @param {boolean} [config.applyThresholding=false] Whether to apply thresholds
   * @param {boolean} [config.applyDifferentialPrivacy=false] Whether to apply differential privacy after
   *   thresholding. This will be used for sending to telemetry.
   * @returns {{[key: string]: number|string}} One entry per interest feature, in model key order, plus
   *   the click count when clicks are a model feature.
   */
  computeInterestVector({
    dataForIntervals,
    indexSchema,
    applyPostProcessing = false,
    applyThresholding = false,
    applyDifferentialPrivacy = false,
  }) {
    const processedPerTimeInterval = dataForIntervals.map(
      (intervalData, idx) => {
        // Sum raw values per aggregated feature within this interval.
        const intervalRawTotal = {};
        intervalData.forEach(aggElement => {
          const feature = aggElement[indexSchema[AggregateResultKeys.FEATURE]];
          const value = aggElement[indexSchema[AggregateResultKeys.VALUE]]; // In the future we could support format here
          dictAdd(intervalRawTotal, feature, value);
        });

        // Combine raw features into interest features, scaled by this interval's time weight.
        const weight = this.dayTimeWeighting.getRelativeWeight(idx);
        const perPeriodTotals = {};
        Object.values(this.interestVectorModel).forEach(interestFeature => {
          for (const featureUsed of Object.keys(
            interestFeature.featureWeights
          )) {
            if (featureUsed in intervalRawTotal) {
              dictAdd(
                perPeriodTotals,
                interestFeature.name,
                intervalRawTotal[featureUsed] *
                  weight *
                  interestFeature.featureWeights[featureUsed]
              );
            }
          }
        });
        return perPeriodTotals;
      }
    );

    // Sum the weighted interval totals. The model is linear, so weighting each
    // interval before summing is equivalent to weighting afterwards.
    let totalResults = {};
    processedPerTimeInterval.forEach(intervalTotals => {
      for (const key of Object.keys(intervalTotals)) {
        dictAdd(totalResults, key, intervalTotals[key]);
      }
    });

    let numClicks = -1;

    // If clicks is a feature, it's handled as a special case: it is kept out of
    // post-processing and re-added unmodified at the end.
    if (SPECIAL_FEATURE_CLICK in totalResults) {
      numClicks = totalResults[SPECIAL_FEATURE_CLICK];
      delete totalResults[SPECIAL_FEATURE_CLICK];
    }

    if (applyPostProcessing) {
      totalResults = this.applyPostProcessing(totalResults);
    }

    if (this.clickScale && numClicks > 0) {
      totalResults = dictApply(totalResults, x => x / numClicks);
    }

    // Emit one entry per model feature, zero-filled, in interestVectorModel key order.
    const zeroFilledResult = {};
    Object.values(this.interestVectorModel).forEach(interestFeature => {
      zeroFilledResult[interestFeature.name] =
        totalResults[interestFeature.name] || 0;
    });
    totalResults = zeroFilledResult;

    if (numClicks >= 0) {
      // Optional
      totalResults[SPECIAL_FEATURE_CLICK] = numClicks;
    }
    if (applyThresholding) {
      this.applyThresholding(totalResults, applyDifferentialPrivacy);
    }
    return totalResults;
  }

  /**
   * Convert float values to discrete buckets, based on the threshold parameters of each feature
   * in the model. Values are modified in place on the provided dictionary; keys that are not
   * model features are left untouched.
   *
   * @param {object} valueDict Dictionary of all values in the model.
   * @param {boolean} applyDifferentialPrivacy Whether to apply differential privacy after thresholding.
   */
  applyThresholding(valueDict, applyDifferentialPrivacy = false) {
    for (const key of Object.keys(valueDict)) {
      if (!(key in this.interestVectorModel)) {
        continue;
      }
      const interestFeature = this.interestVectorModel[key];
      valueDict[key] = interestFeature.applyThresholds(valueDict[key]);
      if (applyDifferentialPrivacy) {
        valueDict[key] = interestFeature.applyDifferentialPrivacy(
          valueDict[key]
        );
      }
    }
  }

  /**
   * Apply the configured post-processing steps, in this order: natural log (ln(x + 1)),
   * division by the maximum value, L1 normalization, then L2 normalization. Every divisor
   * is clamped to at least 1e-6 to avoid division by zero.
   *
   * @param {{[key: string]: number}} valueDict Values to transform.
   * @returns {{[key: string]: number}} Transformed values; `valueDict` itself when no step is enabled.
   */
  applyPostProcessing(valueDict) {
    let res = valueDict;
    if (this.logScale) {
      res = dictApply(valueDict, x => Math.log(x + 1));
    }

    if (this.rescale) {
      let divisor = Math.max(...Object.values(res));
      if (divisor <= 1e-6) {
        divisor = 1e-6;
      }
      res = dictApply(res, x => x / divisor);
    }

    if (this.normalizeL1) {
      let magnitude = Object.values(res).reduce((sum, c) => sum + c, 0);
      if (magnitude <= 1e-6) {
        magnitude = 1e-6;
      }
      res = dictApply(res, x => x / magnitude);
    }

    if (this.normalize) {
      let magnitude = Math.sqrt(
        Object.values(res).reduce((sum, c) => sum + c ** 2, 0)
      );
      if (magnitude <= 1e-6) {
        magnitude = 1e-6;
      }
      res = dictApply(res, x => x / magnitude);
    }
    return res;
  }

  /**
   * Computes interest vectors based on click-through rate (CTR) by dividing the click dictionary
   * by the impression dictionary. Optionally also computes a coarse (thresholded, post-processed)
   * vector and a coarse private vector (thresholded, then randomized with unary-encoding
   * randomized response), if supported by the model.
   *
   * In all cases model_id is returned.
   *
   * @param {object} params - Function parameters.
   * @param {{[key: string]: number}} params.clicks - A dictionary of interest keys to click counts.
   * @param {{[key: string]: number}} params.impressions - A dictionary of interest keys to impression counts.
   * @param {string} [params.model_id="unknown"] - Identifier for the model used in generating the vectors.
   * @param {boolean} [params.condensePrivateValues=true] - If true, condenses coarse private interest values into an array format.
   *
   * @returns {object} result - An object containing one or more of the following:
   * @returns {object} result.inferredInterests - A dictionary of raw CTR interest scores, with `model_id`.
   * @returns {object} [result.coarseInferredInterests] - A dictionary of thresholded interest scores (non-private), if supported.
   * @returns {object} [result.coarsePrivateInferredInterests] - Thresholded interest scores with differential privacy, if supported.
   */
  computeCTRInterestVectors({
    clicks,
    impressions,
    model_id = "unknown",
    condensePrivateValues = true,
  }) {
    const inferredInterests = divideDict(clicks, impressions);

    const resultObject = {
      inferredInterests: { ...inferredInterests, model_id },
    };

    if (this.supportsCoarseInterests()) {
      const coarseValues = this.applyPostProcessing({ ...inferredInterests });
      this.applyThresholding(coarseValues, false);
      resultObject.coarseInferredInterests = { ...coarseValues, model_id };
    }

    if (this.supportsCoarsePrivateInterests()) {
      // Post-process the full vector first (normalization depends on every
      // feature), so the private values are on the same scale as the coarse
      // values that the shared per-feature thresholds were applied to above.
      let coarsePrivateValues = this.applyPostProcessing({
        ...inferredInterests,
      });
      if (this.privateFeatures) {
        // Keep only the configured private features.
        coarsePrivateValues = Object.fromEntries(
          Object.entries(coarsePrivateValues).filter(([key]) =>
            this.privateFeatures.includes(key)
          )
        );
      }
      this.applyThresholding(coarsePrivateValues, true);

      if (condensePrivateValues) {
        resultObject.coarsePrivateInferredInterests = {
          // Key order preserved in Gecko
          values: Object.values(coarsePrivateValues),
          model_id,
        };
      } else {
        resultObject.coarsePrivateInferredInterests = {
          ...coarsePrivateValues,
          model_id,
        };
      }
    }
    return resultObject;
  }

  /**
   * Computes various types of interest vectors from user interaction data across intervals.
   * Returns the standard inferred interests, and optionally coarse-grained (thresholded) and
   * coarse private (thresholded plus randomized response) versions depending on model support.
   *
   * @param {object} params - The function parameters.
   * @param {Array<object>} params.dataForIntervals - Raw aggregate rows grouped by time interval (e.g., clicks, impressions).
   * @param {object} params.indexSchema - Map of keys to column indices in each row.
   * @param {string} [params.model_id="unknown"] - Identifier for the model used to produce these vectors.
   * @param {boolean} [params.condensePrivateValues=true] - If true, condenses coarse private interest values into an array format.
   *
   * @returns {object} result - An object containing the computed interest vectors.
   * @returns {object} result.inferredInterests - A dictionary of inferred interest values, with `model_id`.
   * @returns {object} [result.coarseInferredInterests] - Coarse thresholded (non-private) interest vector, if supported.
   * @returns {object | {values: Array<number|string>, model_id: string}} [result.coarsePrivateInferredInterests] - Coarse and differentially private interests.
   *           If `condensePrivateValues` is true, returned as an object with a `values` array; otherwise, as a dictionary.
   */
  computeInterestVectors({
    dataForIntervals,
    indexSchema,
    model_id = "unknown",
    condensePrivateValues = true,
  }) {
    const inferredInterests = this.computeInterestVector({
      dataForIntervals,
      indexSchema,
    });
    // Include model_id, as documented and as computeCTRInterestVectors does.
    const result = {
      inferredInterests: { ...inferredInterests, model_id },
    };

    if (this.supportsCoarseInterests()) {
      const coarseInferredInterests = this.computeInterestVector({
        dataForIntervals,
        indexSchema,
        applyThresholding: true,
      });
      result.coarseInferredInterests = {
        ...coarseInferredInterests,
        model_id,
      };
    }

    if (this.supportsCoarsePrivateInterests()) {
      const coarsePrivateInferredInterests = this.computeInterestVector({
        dataForIntervals,
        indexSchema,
        applyThresholding: true,
        applyDifferentialPrivacy: true,
      });
      if (condensePrivateValues) {
        result.coarsePrivateInferredInterests = {
          // Key order preserved in Gecko
          values: Object.values(coarsePrivateInferredInterests),
          model_id,
        };
      } else {
        result.coarsePrivateInferredInterests = {
          ...coarsePrivateInferredInterests,
          model_id,
        };
      }
    }
    return result;
  }
}
