using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using MarkovGrams.Utilities;
using SBRL.Algorithms;

namespace MarkovGrams
{
    /// <summary>
    /// A weighted markov chain that builds new strings from a set of weighted n-grams.
    /// Depending on <see cref="Mode"/> it operates either on individual characters or on
    /// space-separated words.
    /// </summary>
    public class WeightedMarkovChain
    {
        /// <summary>
        /// Weighted random selector (from SBRL.Algorithms; presumably picks keys in proportion
        /// to their weight). Between calls it may hold the pool of starting n-grams built by
        /// <see cref="RandomNgram"/>; while <see cref="Generate(int, out float)"/> runs it holds
        /// the candidate n-grams for the current step.
        /// </summary>
        private readonly WeightedRandom<string> wrandom = new WeightedRandom<string>();

        /// <summary>
        /// The n-grams that this markov chain currently contains, mapped to their weights.
        /// </summary>
        private readonly Dictionary<string, double> ngrams;

        /// <summary>
        /// Whether to always start generating a new word from an n-gram that starts with
        /// an uppercase letter.
        /// </summary>
        public bool StartOnUppercase = false;

        /// <summary>
        /// The generation mode to use when running the Markov Chain.
        /// </summary>
        /// <remarks>
        /// The input n-grams must have been generated using the same mode specified here.
        /// Any mode other than <c>GenerationMode.CharacterLevel</c> treats n-grams as words
        /// separated by single spaces.
        /// </remarks>
        public GenerationMode Mode { get; private set; } = GenerationMode.CharacterLevel;

        /// <summary>
        /// Creates a new weighted markov chain from n-grams with floating-point weights.
        /// </summary>
        /// <param name="inNgrams">
        /// The n-grams (mapped to their weights) to populate the new markov chain with.
        /// The dictionary is used directly, not copied.
        /// </param>
        /// <param name="inMode">The mode the n-grams were generated with.</param>
        /// <exception cref="ArgumentNullException"><paramref name="inNgrams"/> is null.</exception>
        public WeightedMarkovChain(Dictionary<string, double> inNgrams, GenerationMode inMode)
        {
            ngrams = inNgrams ?? throw new ArgumentNullException(nameof(inNgrams));
            Mode = inMode;
        }

        /// <summary>
        /// Creates a new weighted markov chain from n-grams with integer weights. The weights
        /// are copied into an internal dictionary as doubles.
        /// </summary>
        /// <param name="inNgrams">The n-grams (mapped to their weights) to populate the new markov chain with.</param>
        /// <param name="inMode">The mode the n-grams were generated with.</param>
        /// <exception cref="ArgumentNullException"><paramref name="inNgrams"/> is null.</exception>
        public WeightedMarkovChain(Dictionary<string, int> inNgrams, GenerationMode inMode)
        {
            if (inNgrams == null)
                throw new ArgumentNullException(nameof(inNgrams));

            ngrams = new Dictionary<string, double>(inNgrams.Count);
            foreach (KeyValuePair<string, int> ngram in inNgrams)
                ngrams[ngram.Key] = ngram.Value;
            Mode = inMode;
        }

        /// <summary>
        /// Returns a random n-gram that's currently loaded into this WeightedMarkovChain,
        /// chosen using the n-gram weights. When <see cref="StartOnUppercase"/> is true, only
        /// n-grams whose first character is uppercase are considered.
        /// </summary>
        /// <remarks>
        /// The pool of starting n-grams is built lazily and kept until the internal selector is
        /// cleared, which <see cref="Generate(int, out float)"/> always does before it returns.
        /// Changing <see cref="StartOnUppercase"/> between repeated direct calls to this method
        /// therefore only takes effect after the next generation run.
        /// </remarks>
        /// <returns>A random n-gram from this WeightedMarkovChain's n-grams.</returns>
        /// <exception cref="InvalidOperationException">There is no suitable starting n-gram.</exception>
        public string RandomNgram()
        {
            if (wrandom.Count == 0)
            {
                Dictionary<string, double> startingNgrams = StartOnUppercase
                    // Guard against empty keys, which would otherwise throw on key[0].
                    ? FilterNgrams(key => key.Length > 0 && char.IsUpper(key[0]))
                    : ngrams;

                if (startingNgrams.Count == 0)
                    throw new InvalidOperationException($"Error: No valid starting ngrams were found (StartOnUppercase: {StartOnUppercase}).");

                wrandom.SetContents(startingNgrams);
            }
            return wrandom.Next();
        }

        /// <summary>
        /// Generates a new random string from the currently stored n-grams.
        /// </summary>
        /// <param name="length">
        /// The length of string to generate, in characters (character mode) or words (word mode).
        /// Note that this is a target, not a fixed value - e.g. passing 2 when the n-gram order is 3 will
        /// result in a string of length 3. Also, depending on the current n-grams this markov chain contains,
        /// it may end up being cut short.
        /// </param>
        /// <returns>A new random string.</returns>
        /// <exception cref="InvalidOperationException">There is no suitable starting n-gram.</exception>
        public string Generate(int length)
        {
            return Generate(length, out _);
        }

        /// <summary>
        /// Generates a new random string from the currently stored n-grams, and reports how
        /// many choices were available along the way.
        /// </summary>
        /// <param name="length">
        /// The target length in characters (character mode) or words (word mode). The result
        /// starts as a whole n-gram, so it may be longer than this; it may also be shorter if
        /// the chain reaches an n-gram with no possible continuation.
        /// </param>
        /// <param name="choicePointRatio">
        /// The total number of candidate n-grams seen across all extension steps (including a
        /// final step that found no candidates), divided by the number of n-grams appended plus one.
        /// </param>
        /// <returns>A new random string.</returns>
        /// <exception cref="InvalidOperationException">There is no suitable starting n-gram.</exception>
        public string Generate(int length, out float choicePointRatio)
        {
            try
            {
                string lastNgram = RandomNgram();
                StringBuilder result = new StringBuilder(lastNgram);

                // Size of the result in the unit `length` is measured in. Each step below appends
                // exactly one character (or one space + one space-free word), so it is tracked
                // incrementally instead of re-scanning the whole string every iteration.
                int currentLength = Mode == GenerationMode.CharacterLevel
                    ? lastNgram.Length
                    : lastNgram.CountCharInstances(" ".ToCharArray()) + 1;

                long totalChoices = 0;
                int appendedCount = 0;

                while (currentLength < length)
                {
                    wrandom.ClearContents();

                    // Every n-gram that can legally follow the last one
                    string prefix = NextNgramPrefix(lastNgram);
                    Dictionary<string, double> candidates = FilterNgrams(key => key.StartsWithFast(prefix));
                    totalChoices += candidates.Count;

                    // If there aren't any choices left, we can't exactly keep adding to the new string any more :-(
                    if (candidates.Count == 0)
                        break;

                    wrandom.SetContents(candidates);
                    string nextNgram = wrandom.Next();

                    // Only the final character / word of the chosen n-gram is new.
                    if (Mode == GenerationMode.CharacterLevel)
                        result.Append(nextNgram[nextNgram.Length - 1]);
                    else
                        result.Append(' ').Append(nextNgram.Substring(nextNgram.LastIndexOf(' ') + 1));

                    currentLength++;
                    lastNgram = nextNgram;
                    appendedCount++;
                }

                choicePointRatio = (float)totalChoices / (appendedCount + 1);
                return result.ToString();
            }
            finally
            {
                // Always leave the shared selector empty so a later RandomNgram() call rebuilds
                // the starting pool instead of drawing from stale per-step candidates.
                wrandom.ClearContents();
            }
        }

        /// <summary>
        /// Returns the prefix that the n-gram following <paramref name="lastNgram"/> must start with.
        /// </summary>
        private string NextNgramPrefix(string lastNgram)
        {
            if (Mode == GenerationMode.CharacterLevel)
                return lastNgram.Substring(1);

            string overlap = string.Join(" ", lastNgram.Split(' ').Skip(1));
            // Require the overlap to end on a word boundary so that e.g. "the cat" does not
            // match an n-gram such as "the catalog is". An empty overlap (single-word n-grams)
            // matches everything, as before.
            return overlap.Length == 0 ? overlap : overlap + " ";
        }

        /// <summary>
        /// Returns a new dictionary containing the n-grams (with their weights) whose key
        /// satisfies <paramref name="predicate"/>. The scan is done in parallel with PLINQ.
        /// </summary>
        private Dictionary<string, double> FilterNgrams(Func<string, bool> predicate)
        {
            // Keys come from a Dictionary, so they are unique and ToDictionary cannot collide.
            return ngrams.AsParallel()
                .Where(pair => predicate(pair.Key))
                .ToDictionary(pair => pair.Key, pair => pair.Value);
        }
    }
}